Removed obsolete log outputs
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -226,9 +226,7 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
print((
|
||||
f"[{_epoch + 1:d}/{train_epoch:d}] - "
|
||||
f"train time: {outputs['per_epoch_time']:.2f}, "
|
||||
f"Decoder loss: {outputs['decoder_loss']:.3f}, "
|
||||
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
|
||||
f"Encoder loss: {outputs['encoder_loss']:.3f}"
|
||||
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
|
||||
))
|
||||
|
||||
# save sample image summary
|
||||
@ -249,31 +247,15 @@ def train_simple(dataset: tf.data.Dataset,
|
||||
strike = False
|
||||
total_strike = False
|
||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||
outputs['xd_loss'] + outputs['zd_loss']
|
||||
outputs['xd_loss'] + outputs['zd_loss']
|
||||
if total_loss < total_lowest_loss:
|
||||
total_lowest_loss = total_loss
|
||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||
total_strike = True
|
||||
if outputs['encoder_loss'] < encoder_lowest_loss:
|
||||
encoder_lowest_loss = outputs['encoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['decoder_loss'] < decoder_lowest_loss:
|
||||
decoder_lowest_loss = outputs['decoder_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
||||
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['xd_loss'] < xd_lowest_loss:
|
||||
xd_lowest_loss = outputs['xd_loss']
|
||||
else:
|
||||
strike = True
|
||||
if outputs['zd_loss'] < zd_lowest_loss:
|
||||
zd_lowest_loss = outputs['zd_loss']
|
||||
else:
|
||||
strike = True
|
||||
|
||||
if strike and total_strike:
|
||||
grace_period -= 1
|
||||
@ -309,8 +291,6 @@ def _train_one_epoch_simple(epoch: int,
|
||||
epoch_var.assign(epoch)
|
||||
epoch_start_time = time.time()
|
||||
# define loss variables
|
||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
||||
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
||||
|
||||
# update learning rate
|
||||
@ -343,8 +323,6 @@ def _train_one_epoch_simple(epoch: int,
|
||||
|
||||
# final losses of epoch
|
||||
outputs = {
|
||||
'decoder_loss': decoder_loss_avg.result(False),
|
||||
'encoder_loss': encoder_loss_avg.result(False),
|
||||
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
||||
'per_epoch_time': per_epoch_time,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user