From c52a3cdfb9eed54f89742bf388cd6e3b45bb428f Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 7 Feb 2019 18:06:33 +0100 Subject: [PATCH] Added names and types to loss metrics Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 8f2ed77..851ecce 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -110,11 +110,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i for epoch in range(train_epoch): # define loss variables - encoder_loss_avg = tfe.metrics.Mean() - decoder_loss_avg = tfe.metrics.Mean() - enc_dec_loss_avg = tfe.metrics.Mean() - zd_loss_avg = tfe.metrics.Mean() - xd_loss_avg = tfe.metrics.Mean() + encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) + decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) + enc_dec_loss_avg = tfe.metrics.Mean(name='enc_dec_loss', dtype=tf.float32) + zd_loss_avg = tfe.metrics.Mean(name='zd_loss', dtype=tf.float32) + xd_loss_avg = tfe.metrics.Mean(name='xd_loss', dtype=tf.float32) epoch_start_time = time.time()