Improved names of loss metrics
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -120,9 +120,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
# define loss variables
|
# define loss variables
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||||
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
||||||
enc_dec_loss_avg = tfe.metrics.Mean(name='enc_dec_loss', dtype=tf.float32)
|
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
||||||
zd_loss_avg = tfe.metrics.Mean(name='zd_loss', dtype=tf.float32)
|
zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
|
||||||
xd_loss_avg = tfe.metrics.Mean(name='xd_loss', dtype=tf.float32)
|
xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)
|
||||||
|
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user