Made early stopping conditional and turned it off
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -123,7 +123,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
weights_prefix: str,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
verbose: bool = True) -> None:
|
||||
verbose: bool = True, early_stopping: bool = False) -> None:
|
||||
"""
|
||||
Trains AAE for given data set.
|
||||
|
||||
@ -154,6 +154,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
batch_size: the size of each batch (default: 128)
|
||||
train_epoch: number of epochs to train (default: 80)
|
||||
verbose: if True prints train progress info to console (default: True)
|
||||
early_stopping: if True the early stopping mechanic is enabled (default: False)
|
||||
"""
|
||||
|
||||
# non-preserved tensors
|
||||
@ -243,6 +244,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
if early_stopping:
|
||||
strike = False
|
||||
total_strike = False
|
||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||
|
||||
Reference in New Issue
Block a user