Removed obsolete log outputs
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -226,9 +226,7 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
print((
|
print((
|
||||||
f"[{_epoch + 1:d}/{train_epoch:d}] - "
|
f"[{_epoch + 1:d}/{train_epoch:d}] - "
|
||||||
f"train time: {outputs['per_epoch_time']:.2f}, "
|
f"train time: {outputs['per_epoch_time']:.2f}, "
|
||||||
f"Decoder loss: {outputs['decoder_loss']:.3f}, "
|
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}"
|
||||||
f"Encoder + Decoder loss: {outputs['enc_dec_loss']:.3f}, "
|
|
||||||
f"Encoder loss: {outputs['encoder_loss']:.3f}"
|
|
||||||
))
|
))
|
||||||
|
|
||||||
# save sample image summary
|
# save sample image summary
|
||||||
@ -249,31 +247,15 @@ def train_simple(dataset: tf.data.Dataset,
|
|||||||
strike = False
|
strike = False
|
||||||
total_strike = False
|
total_strike = False
|
||||||
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
total_loss = outputs['encoder_loss'] + outputs['decoder_loss'] + outputs['enc_dec_loss'] + \
|
||||||
outputs['xd_loss'] + outputs['zd_loss']
|
outputs['xd_loss'] + outputs['zd_loss']
|
||||||
if total_loss < total_lowest_loss:
|
if total_loss < total_lowest_loss:
|
||||||
total_lowest_loss = total_loss
|
total_lowest_loss = total_loss
|
||||||
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
elif total_loss > TOTAL_LOSS_GRACE_CAP:
|
||||||
total_strike = True
|
total_strike = True
|
||||||
if outputs['encoder_loss'] < encoder_lowest_loss:
|
|
||||||
encoder_lowest_loss = outputs['encoder_loss']
|
|
||||||
else:
|
|
||||||
strike = True
|
|
||||||
if outputs['decoder_loss'] < decoder_lowest_loss:
|
|
||||||
decoder_lowest_loss = outputs['decoder_loss']
|
|
||||||
else:
|
|
||||||
strike = True
|
|
||||||
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
if outputs['enc_dec_loss'] < enc_dec_lowest_loss:
|
||||||
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
enc_dec_lowest_loss = outputs['enc_dec_loss']
|
||||||
else:
|
else:
|
||||||
strike = True
|
strike = True
|
||||||
if outputs['xd_loss'] < xd_lowest_loss:
|
|
||||||
xd_lowest_loss = outputs['xd_loss']
|
|
||||||
else:
|
|
||||||
strike = True
|
|
||||||
if outputs['zd_loss'] < zd_lowest_loss:
|
|
||||||
zd_lowest_loss = outputs['zd_loss']
|
|
||||||
else:
|
|
||||||
strike = True
|
|
||||||
|
|
||||||
if strike and total_strike:
|
if strike and total_strike:
|
||||||
grace_period -= 1
|
grace_period -= 1
|
||||||
@ -309,8 +291,6 @@ def _train_one_epoch_simple(epoch: int,
|
|||||||
epoch_var.assign(epoch)
|
epoch_var.assign(epoch)
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
# define loss variables
|
# define loss variables
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
|
||||||
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
|
||||||
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
||||||
|
|
||||||
# update learning rate
|
# update learning rate
|
||||||
@ -343,8 +323,6 @@ def _train_one_epoch_simple(epoch: int,
|
|||||||
|
|
||||||
# final losses of epoch
|
# final losses of epoch
|
||||||
outputs = {
|
outputs = {
|
||||||
'decoder_loss': decoder_loss_avg.result(False),
|
|
||||||
'encoder_loss': encoder_loss_avg.result(False),
|
|
||||||
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
||||||
'per_epoch_time': per_epoch_time,
|
'per_epoch_time': per_epoch_time,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user