From 75d7c769c742b49665587534607c04f2f33f72ab Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 18:22:25 +0100 Subject: [PATCH] Made early stopping conditional and turned it off Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 78 ++++++++++++------------ 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index a5743cc..aeb176c 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -123,7 +123,7 @@ def train(dataset: tf.data.Dataset, iteration: int, weights_prefix: str, channels: int = 1, zsize: int = 32, lr: float = 0.002, batch_size: int = 128, train_epoch: int = 80, - verbose: bool = True) -> None: + verbose: bool = True, early_stopping: bool = False) -> None: """ Trains AAE for given data set. @@ -154,6 +154,7 @@ def train(dataset: tf.data.Dataset, iteration: int, batch_size: the size of each batch (default: 128) train_epoch: number of epochs to train (default: 80) verbose: if True prints train progress info to console (default: True) + early_stopping: if True the early stopping mechanic is enabled (default: False) """ # non-preserved tensors @@ -243,44 +244,45 @@ def train(dataset: tf.data.Dataset, iteration: int, checkpoint.save(checkpoint_prefix) # check for improvements in error reduction - otherwise early stopping - strike = False - total_strike = False - total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \ - outputs['xd_loss'] + outputs['zd_loss'] - if total_loss < total_lowest_loss: - total_lowest_loss = total_loss - elif total_loss > TOTAL_LOSS_GRACE_CAP: - total_strike = True - if outputs['encoder_loss'] < encoder_lowest_loss: - encoder_lowest_loss = outputs['encoder_loss'] - else: - strike = True - if outputs['decoder_loss'] < decoder_lowest_loss: - decoder_lowest_loss = outputs['decoder_loss'] - else: - strike = True - if outputs['enc_dec_loss'] < enc_dec_lowest_loss: - enc_dec_lowest_loss = outputs['enc_dec_loss'] - else: - strike = True - if outputs['xd_loss'] < xd_lowest_loss: - xd_lowest_loss = outputs['xd_loss'] - else: - strike = True - if outputs['zd_loss'] < zd_lowest_loss: - zd_lowest_loss = outputs['zd_loss'] - else: - strike = True - - if strike and total_strike: - grace_period -= 1 - elif strike: - pass - else: - grace_period = GRACE + if early_stopping: + strike = False + total_strike = False + total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \ + outputs['xd_loss'] + outputs['zd_loss'] + if total_loss < total_lowest_loss: + total_lowest_loss = total_loss + elif total_loss > TOTAL_LOSS_GRACE_CAP: + total_strike = True + if outputs['encoder_loss'] < encoder_lowest_loss: + encoder_lowest_loss = outputs['encoder_loss'] + else: + strike = True + if outputs['decoder_loss'] < decoder_lowest_loss: + decoder_lowest_loss = outputs['decoder_loss'] + else: + strike = True + if outputs['enc_dec_loss'] < enc_dec_lowest_loss: + enc_dec_lowest_loss = outputs['enc_dec_loss'] + else: + strike = True + if outputs['xd_loss'] < xd_lowest_loss: + xd_lowest_loss = outputs['xd_loss'] + else: + strike = True + if outputs['zd_loss'] < zd_lowest_loss: + zd_lowest_loss = outputs['zd_loss'] + else: + strike = True - if grace_period == 0: - break + if strike and total_strike: + grace_period -= 1 + elif strike: + pass + else: + grace_period = GRACE + + if grace_period == 0: + break if verbose: if grace_period > 0: