Made early stopping conditional and turned it off
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -123,7 +123,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||||
batch_size: int = 128, train_epoch: int = 80,
|
batch_size: int = 128, train_epoch: int = 80,
|
||||||
verbose: bool = True) -> None:
|
verbose: bool = True, early_stopping: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Trains AAE for given data set.
|
Trains AAE for given data set.
|
||||||
|
|
||||||
@ -154,6 +154,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
batch_size: the size of each batch (default: 128)
|
batch_size: the size of each batch (default: 128)
|
||||||
train_epoch: number of epochs to train (default: 80)
|
train_epoch: number of epochs to train (default: 80)
|
||||||
verbose: if True prints train progress info to console (default: True)
|
verbose: if True prints train progress info to console (default: True)
|
||||||
|
early_stopping: if True the early stopping mechanic is enabled (default: False)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# non-preserved tensors
|
# non-preserved tensors
|
||||||
@ -243,6 +244,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
checkpoint.save(checkpoint_prefix)
|
checkpoint.save(checkpoint_prefix)
|
||||||
|
|
||||||
# check for improvements in error reduction - otherwise early stopping
|
# check for improvements in error reduction - otherwise early stopping
|
||||||
|
if early_stopping:
|
||||||
strike = False
|
strike = False
|
||||||
total_strike = False
|
total_strike = False
|
||||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||||
|
|||||||
Reference in New Issue
Block a user