Made early stopping conditional and turned it off
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -123,7 +123,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
weights_prefix: str,
|
||||
channels: int = 1, zsize: int = 32, lr: float = 0.002,
|
||||
batch_size: int = 128, train_epoch: int = 80,
|
||||
verbose: bool = True) -> None:
|
||||
verbose: bool = True, early_stopping: bool = False) -> None:
|
||||
"""
|
||||
Trains AAE for given data set.
|
||||
|
||||
@ -154,6 +154,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
batch_size: the size of each batch (default: 128)
|
||||
train_epoch: number of epochs to train (default: 80)
|
||||
verbose: if True prints train progress info to console (default: True)
|
||||
early_stopping: if True the early stopping mechanic is enabled (default: False)
|
||||
"""
|
||||
|
||||
# non-preserved tensors
|
||||
@ -243,44 +244,45 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
strike = False
|
||||
total_strike = False
|
||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||
outputs['xd_loss'] + outputs['zd_loss']
|
||||
if total_loss < total_lowest_loss:
|
||||
total_lowest_loss = total_loss
|
||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||
total_strike = True
|
||||
if outputs['encoder_loss'] < encoder_lowest_loss:
|
||||
encoder_lowest_loss = outputs['encoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['decoder_loss'] < decoder_lowest_loss:
|
||||
decoder_lowest_loss = outputs['decoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
||||
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['xd_loss'] < xd_lowest_loss:
|
||||
xd_lowest_loss = outputs['xd_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['zd_loss'] < zd_lowest_loss:
|
||||
zd_lowest_loss = outputs['zd_loss']
|
||||
else:
|
||||
strike = True
|
||||
if early_stopping:
|
||||
strike = False
|
||||
total_strike = False
|
||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||
outputs['xd_loss'] + outputs['zd_loss']
|
||||
if total_loss < total_lowest_loss:
|
||||
total_lowest_loss = total_loss
|
||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||
total_strike = True
|
||||
if outputs['encoder_loss'] < encoder_lowest_loss:
|
||||
encoder_lowest_loss = outputs['encoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['decoder_loss'] < decoder_lowest_loss:
|
||||
decoder_lowest_loss = outputs['decoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
||||
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['xd_loss'] < xd_lowest_loss:
|
||||
xd_lowest_loss = outputs['xd_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['zd_loss'] < zd_lowest_loss:
|
||||
zd_lowest_loss = outputs['zd_loss']
|
||||
else:
|
||||
strike = True
|
||||
|
||||
if strike and total_strike:
|
||||
grace_period -= 1
|
||||
elif strike:
|
||||
pass
|
||||
else:
|
||||
grace_period = GRACE
|
||||
if strike and total_strike:
|
||||
grace_period -= 1
|
||||
elif strike:
|
||||
pass
|
||||
else:
|
||||
grace_period = GRACE
|
||||
|
||||
if grace_period == 0:
|
||||
break
|
||||
if grace_period == 0:
|
||||
break
|
||||
|
||||
if verbose:
|
||||
if grace_period > 0:
|
||||
|
||||
Reference in New Issue
Block a user