From 3b9742a1b492de342373affe9c086a043ec2ee76 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 18:40:41 +0100 Subject: [PATCH] Improved summary logging Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 209 ++++++++++++----------- 1 file changed, 112 insertions(+), 97 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index aeb176c..22490bf 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -24,6 +24,7 @@ Attributes: GRACE: specifies the number of epochs that the training loss can stagnate or worsen before the training is stopped early TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher + LOG_FREQUENCY: number of steps that must pass before logging happens Functions: prepare_training_data(...): prepares the mnist training data @@ -53,6 +54,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy GRACE: int = 10 TOTAL_LOSS_GRACE_CAP: int = 6 +LOG_FREQUENCY: int = 10 def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int, @@ -236,7 +238,7 @@ def train(dataset: tf.data.Dataset, iteration: int, def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None: resultsample = decoder(sample).cpu() grid = prepare_image(resultsample) - summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), + summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0), max_images=1, step=global_step_decoder) _save_sample(**checkpointables) @@ -308,100 +310,93 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens global_step_enc_dec: tf.Variable, epoch_var: tf.Variable) -> Dict[str, float]: - epoch_var.assign(epoch) - epoch_start_time = time.time() - # define loss variables - encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) - decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) - enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) - zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32) - xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32) - - # update learning rate - if (epoch + 1) % 30 == 0: - learning_rate_var.assign(learning_rate_var.value() / 4) - if verbose: - print("learning rate change!") - - log_frequency = 10 - batch_iteration = k.variable(0, dtype=tf.int64) - for x, _ in dataset: - # x discriminator - _xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator, - decoder=decoder, - optimizer=x_discriminator_optimizer, - inputs=x, - targets_real=targets_real, - targets_fake=targets_fake, - global_step=global_step_xd, - z_generator=z_generator) - xd_loss_avg(_xd_train_loss) + with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY, + global_step=global_step_decoder): + epoch_var.assign(epoch) + epoch_start_time = time.time() + # define loss variables + encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) + decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32) + enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32) + zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32) + xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32) - # -------- - # decoder - _decoder_train_loss = _train_decoder_step(decoder=decoder, - x_discriminator=x_discriminator, - optimizer=decoder_optimizer, - targets=targets_real, - global_step=global_step_decoder, - z_generator=z_generator) - decoder_loss_avg(_decoder_train_loss) + # update learning rate + if (epoch + 1) % 30 == 0: + learning_rate_var.assign(learning_rate_var.value() / 4) + if verbose: + print("learning rate change!") - # --------- - # z discriminator - _zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator, - encoder=encoder, - optimizer=z_discriminator_optimizer, - inputs=x, - targets_real=targets_real, - targets_fake=targets_fake, - global_step=global_step_zd, - z_generator=z_generator) - zd_loss_avg(_zd_train_loss) - - # ----------- - # encoder + decoder - encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder, - decoder=decoder, - z_discriminator=z_discriminator, - optimizer=enc_dec_optimizer, - inputs=x, - targets=targets_real, - global_step=global_step_enc_dec) - enc_dec_loss_avg(reconstruction_loss) - encoder_loss_avg(encoder_loss) - - if int(global_step_decoder % log_frequency) == 0: - # log the losses every log frequency batches - summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec) - summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder) - summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec) - summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd) - summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd) - - if int(batch_iteration) == 0: - comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) - grid = prepare_image(comparison.cpu(), nrow=64) - summary_ops_v2.image(name='reconstruction_' + str(epoch), - tensor=k.expand_dims(grid, axis=0), max_images=1, - step=global_step_decoder) - - batch_iteration.assign_add(1) + batch_iteration = k.variable(0, dtype=tf.int64) + for x, _ in dataset: + # x discriminator + _xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator, + decoder=decoder, + optimizer=x_discriminator_optimizer, + inputs=x, + targets_real=targets_real, + targets_fake=targets_fake, + global_step=global_step_xd, + z_generator=z_generator) + xd_loss_avg(_xd_train_loss) + + # -------- + # decoder + _decoder_train_loss = _train_decoder_step(decoder=decoder, + x_discriminator=x_discriminator, + optimizer=decoder_optimizer, + targets=targets_real, + global_step=global_step_decoder, + z_generator=z_generator) + decoder_loss_avg(_decoder_train_loss) + + # --------- + # z discriminator + _zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator, + encoder=encoder, + optimizer=z_discriminator_optimizer, + inputs=x, + targets_real=targets_real, + targets_fake=targets_fake, + global_step=global_step_zd, + z_generator=z_generator) + zd_loss_avg(_zd_train_loss) + + # ----------- + # encoder + decoder + encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder, + decoder=decoder, + z_discriminator=z_discriminator, + optimizer=enc_dec_optimizer, + inputs=x, + targets=targets_real, + global_step=global_step_enc_dec) + enc_dec_loss_avg(reconstruction_loss) + encoder_loss_avg(encoder_loss) - epoch_end_time = time.time() - per_epoch_time = epoch_end_time - epoch_start_time - - # final losses of epoch - outputs = { - 'decoder_loss': decoder_loss_avg.result(False), - 'encoder_loss': encoder_loss_avg.result(False), - 'enc_dec_loss': enc_dec_loss_avg.result(False), - 'xd_loss': xd_loss_avg.result(False), - 'zd_loss': zd_loss_avg.result(False), - 'per_epoch_time': per_epoch_time, - } - - return outputs + if int(batch_iteration) == 0: + comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) + grid = prepare_image(comparison.cpu(), nrow=64) + summary_ops_v2.image(name='reconstruction', + tensor=k.expand_dims(grid, axis=0), max_images=1, + step=global_step_decoder) + + batch_iteration.assign_add(1) + + epoch_end_time = time.time() + per_epoch_time = epoch_end_time - epoch_start_time + + # final losses of epoch + outputs = { + 'decoder_loss': decoder_loss_avg.result(False), + 'encoder_loss': encoder_loss_avg.result(False), + 'enc_dec_loss': enc_dec_loss_avg.result(False), + 'xd_loss': xd_loss_avg.result(False), + 'zd_loss': zd_loss_avg.result(False), + 'per_epoch_time': per_epoch_time, + } + + return outputs def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, @@ -432,7 +427,13 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2) _xd_train_loss = xd_real_loss + xd_fake_loss - + + summary_ops_v2.scalar(name='x_discriminator_real_loss', tensor=xd_real_loss, + step=global_step) + summary_ops_v2.scalar(name='x_discriminator_fake_loss', tensor=xd_fake_loss, + step=global_step) + summary_ops_v2.scalar(name='x_discriminator_loss', tensor=_xd_train_loss, + step=global_step) xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables) optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables), global_step=global_step) @@ -461,7 +462,9 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator, x_fake = decoder(z) xd_result = tf.squeeze(x_discriminator(x_fake)) _decoder_train_loss = binary_crossentropy(targets, xd_result) - + + summary_ops_v2.scalar(name='decoder_loss', tensor=_decoder_train_loss, + step=global_step) grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables) optimizer.apply_gradients(zip(grads, decoder.trainable_variables), global_step=global_step) @@ -498,7 +501,13 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder zd_fake_loss = binary_crossentropy(targets_fake, zd_result) _zd_train_loss = zd_real_loss + zd_fake_loss - + + summary_ops_v2.scalar(name='z_discriminator_real_loss', tensor=zd_real_loss, + step=global_step) + summary_ops_v2.scalar(name='z_discriminator_fake_loss', tensor=zd_fake_loss, + step=global_step) + summary_ops_v2.scalar(name='z_discriminator_loss', tensor=_zd_train_loss, + step=global_step) zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables) optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables), global_step=global_step) @@ -529,7 +538,13 @@ def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDi encoder_loss = binary_crossentropy(targets, zd_result) * 2.0 reconstruction_loss = binary_crossentropy(inputs, x_decoded) _enc_dec_train_loss = encoder_loss + reconstruction_loss - + + summary_ops_v2.scalar(name='encoder_loss', tensor=encoder_loss, + step=global_step) + summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss, + step=global_step) + summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss, + step=global_step) enc_dec_grads = tape.gradient(_enc_dec_train_loss, encoder.trainable_variables + decoder.trainable_variables) optimizer.apply_gradients(zip(enc_dec_grads, @@ -570,6 +585,6 @@ if __name__ == "__main__": total_classes=10) train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) - with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): + with train_summary_writer.as_default(): train(dataset=train_dataset, iteration=iteration, weights_prefix='weights/' + str(inlier_classes[0]) + '/')