diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 9f546e6..64b3da1 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -120,9 +120,9 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i # define loss variables encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) - enc_dec_loss_avg = tfe.metrics.Mean(name='enc_dec_loss', dtype=tf.float32) - zd_loss_avg = tfe.metrics.Mean(name='zd_loss', dtype=tf.float32) - xd_loss_avg = tfe.metrics.Mean(name='xd_loss', dtype=tf.float32) + enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) + zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32) + xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32) epoch_start_time = time.time()