diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 8fa69f2..be94b7e 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -143,6 +143,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i print("learning rate change!") nr_batches = len(mnist_train_x) // batch_size + log_frequency = 10 for it in range(nr_batches): x = k.expand_dims(extract_batch(mnist_train_x, it, batch_size)) # x discriminator @@ -189,6 +190,14 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i global_step=global_step_enc_dec) enc_dec_loss_avg(reconstruction_loss) encoder_loss_avg(encoder_loss) + + if it % log_frequency == 0: + # log the losses every log frequency batches + summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(), step=global_step_enc_dec) + summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(), step=global_step_decoder) + summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(), step=global_step_enc_dec) + summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(), step=global_step_zd) + summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(), step=global_step_xd) if it == 0: directory = 'results' + str(inlier_classes[0])