Made early stopping conditional and turned it off

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 18:22:25 +01:00
parent ea1b4943cf
commit 75d7c769c7

View File

@ -123,7 +123,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
weights_prefix: str,
channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80,
verbose: bool = True) -> None:
verbose: bool = True, early_stopping: bool = False) -> None:
"""
Trains AAE for given data set.
@ -154,6 +154,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
batch_size: the size of each batch (default: 128)
train_epoch: number of epochs to train (default: 80)
verbose: if True prints train progress info to console (default: True)
early_stopping: if True the early stopping mechanic is enabled (default: False)
"""
# non-preserved tensors
@ -243,44 +244,45 @@ def train(dataset: tf.data.Dataset, iteration: int,
checkpoint.save(checkpoint_prefix)
# check for improvements in error reduction - otherwise early stopping
strike = False
total_strike = False
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
outputs['xd_loss'] + outputs['zd_loss']
if total_loss < total_lowest_loss:
total_lowest_loss = total_loss
elif total_loss > TOTAL_LOSS_GRACE_CAP:
total_strike = True
if outputs['encoder_loss'] < encoder_lowest_loss:
encoder_lowest_loss = outputs['encoder_loss']
else:
strike = True
if outputs['decoder_loss'] < decoder_lowest_loss:
decoder_lowest_loss = outputs['decoder_loss']
else:
strike = True
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
enc_dec_lowest_loss = outputs['enc_dec_loss']
else:
strike = True
if outputs['xd_loss'] < xd_lowest_loss:
xd_lowest_loss = outputs['xd_loss']
else:
strike = True
if outputs['zd_loss'] < zd_lowest_loss:
zd_lowest_loss = outputs['zd_loss']
else:
strike = True
if early_stopping:
strike = False
total_strike = False
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
outputs['xd_loss'] + outputs['zd_loss']
if total_loss < total_lowest_loss:
total_lowest_loss = total_loss
elif total_loss > TOTAL_LOSS_GRACE_CAP:
total_strike = True
if outputs['encoder_loss'] < encoder_lowest_loss:
encoder_lowest_loss = outputs['encoder_loss']
else:
strike = True
if outputs['decoder_loss'] < decoder_lowest_loss:
decoder_lowest_loss = outputs['decoder_loss']
else:
strike = True
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
enc_dec_lowest_loss = outputs['enc_dec_loss']
else:
strike = True
if outputs['xd_loss'] < xd_lowest_loss:
xd_lowest_loss = outputs['xd_loss']
else:
strike = True
if outputs['zd_loss'] < zd_lowest_loss:
zd_lowest_loss = outputs['zd_loss']
else:
strike = True
if strike and total_strike:
grace_period -= 1
elif strike:
pass
else:
grace_period = GRACE
if strike and total_strike:
grace_period -= 1
elif strike:
pass
else:
grace_period = GRACE
if grace_period == 0:
break
if grace_period == 0:
break
if verbose:
if grace_period > 0: