Removed obsolete log outputs

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-30 12:12:27 +01:00
parent 8ef7d5f999
commit 855d8764d0

View File

@ -226,9 +226,7 @@ def train_simple(dataset: tf.data.Dataset,
print(( print((
f"[{_epoch + 1:d}/{train_epoch:d}] - " f"[{_epoch + 1:d}/{train_epoch:d}] - "
f"train time: {outputs['per_epoch_time']:.2f}, " f"train time: {outputs['per_epoch_time']:.2f}, "
f"Decoder loss: {outputs['decoder_loss']:.3f}, " f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
f"Encoder loss: {outputs['encoder_loss']:.3f}"
)) ))
# save sample image summary # save sample image summary
@ -249,31 +247,15 @@ def train_simple(dataset: tf.data.Dataset,
strike = False strike = False
total_strike = False total_strike = False
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \ total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
outputs['xd_loss'] + outputs['zd_loss'] outputs['xd_loss'] + outputs['zd_loss']
if total_loss < total_lowest_loss: if total_loss < total_lowest_loss:
total_lowest_loss = total_loss total_lowest_loss = total_loss
elif total_loss > TOTAL_LOSS_GRACE_CAP: elif total_loss > TOTAL_LOSS_GRACE_CAP:
total_strike = True total_strike = True
if outputs['encoder_loss'] < encoder_lowest_loss:
encoder_lowest_loss = outputs['encoder_loss']
else:
strike = True
if outputs['decoder_loss'] < decoder_lowest_loss:
decoder_lowest_loss = outputs['decoder_loss']
else:
strike = True
if outputs['enc_dec_loss'] < enc_dec_lowest_loss: if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
enc_dec_lowest_loss = outputs['enc_dec_loss'] enc_dec_lowest_loss = outputs['enc_dec_loss']
else: else:
strike = True strike = True
if outputs['xd_loss'] < xd_lowest_loss:
xd_lowest_loss = outputs['xd_loss']
else:
strike = True
if outputs['zd_loss'] < zd_lowest_loss:
zd_lowest_loss = outputs['zd_loss']
else:
strike = True
if strike and total_strike: if strike and total_strike:
grace_period -= 1 grace_period -= 1
@ -309,8 +291,6 @@ def _train_one_epoch_simple(epoch: int,
epoch_var.assign(epoch) epoch_var.assign(epoch)
epoch_start_time = time.time() epoch_start_time = time.time()
# define loss variables # define loss variables
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
# update learning rate # update learning rate
@ -343,8 +323,6 @@ def _train_one_epoch_simple(epoch: int,
# final losses of epoch # final losses of epoch
outputs = { outputs = {
'decoder_loss': decoder_loss_avg.result(False),
'encoder_loss': encoder_loss_avg.result(False),
'enc_dec_loss': enc_dec_loss_avg.result(False), 'enc_dec_loss': enc_dec_loss_avg.result(False),
'per_epoch_time': per_epoch_time, 'per_epoch_time': per_epoch_time,
} }