Added names and types to loss metrics

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-07 18:06:33 +01:00
parent cce0975a13
commit c52a3cdfb9

View File

@ -110,11 +110,11 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
for epoch in range(train_epoch): for epoch in range(train_epoch):
# define loss variables # define loss variables
encoder_loss_avg = tfe.metrics.Mean() encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
decoder_loss_avg = tfe.metrics.Mean() decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
enc_dec_loss_avg = tfe.metrics.Mean() enc_dec_loss_avg = tfe.metrics.Mean(name='enc_dec_loss', dtype=tf.float32)
zd_loss_avg = tfe.metrics.Mean() zd_loss_avg = tfe.metrics.Mean(name='zd_loss', dtype=tf.float32)
xd_loss_avg = tfe.metrics.Mean() xd_loss_avg = tfe.metrics.Mean(name='xd_loss', dtype=tf.float32)
epoch_start_time = time.time() epoch_start_time = time.time()