From 855d8764d025d76f5c3c0a0dc3cb57553f128037 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Sat, 30 Mar 2019 12:12:27 +0100 Subject: [PATCH] Removed obsolete log outputs Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 26 ++---------------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 0172e67..9e39c63 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -226,9 +226,7 @@ def train_simple(dataset: tf.data.Dataset, print(( f"[{_epoch + 1:d}/{train_epoch:d}] - " f"train time: {outputs['per_epoch_time']:.2f}, " - f"Decoder loss: {outputs['decoder_loss']:.3f}, " - f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, " - f"Encoder loss: {outputs['encoder_loss']:.3f}" + f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}" )) # save sample image summary @@ -249,31 +247,15 @@ def train_simple(dataset: tf.data.Dataset, strike = False total_strike = False total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \ - outputs['xd_loss'] + outputs['zd_loss'] + outputs['xd_loss'] + outputs['zd_loss'] if total_loss < total_lowest_loss: total_lowest_loss = total_loss elif total_loss > TOTAL_LOSS_GRACE_CAP: total_strike = True - if outputs['encoder_loss'] < encoder_lowest_loss: - encoder_lowest_loss = outputs['encoder_loss'] - else: - strike = True - if outputs['decoder_loss'] < decoder_lowest_loss: - decoder_lowest_loss = outputs['decoder_loss'] - else: - strike = True if outputs['enc_dec_loss'] < enc_dec_lowest_loss: enc_dec_lowest_loss = outputs['enc_dec_loss'] else: strike = True - if outputs['xd_loss'] < xd_lowest_loss: - xd_lowest_loss = outputs['xd_loss'] - else: - strike = True - if outputs['zd_loss'] < zd_lowest_loss: - zd_lowest_loss = outputs['zd_loss'] - else: - strike = True if strike and total_strike: grace_period -= 1 @@ -309,8 +291,6 @@ def _train_one_epoch_simple(epoch: int, epoch_var.assign(epoch) epoch_start_time = time.time() # define loss variables - encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) - decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) # update learning rate @@ -343,8 +323,6 @@ def _train_one_epoch_simple(epoch: int, # final losses of epoch outputs = { - 'decoder_loss': decoder_loss_avg.result(False), - 'encoder_loss': encoder_loss_avg.result(False), 'enc_dec_loss': enc_dec_loss_avg.result(False), 'per_epoch_time': per_epoch_time, }