Improved summary logging

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 18:40:41 +01:00
parent 75d7c769c7
commit 3b9742a1b4

View File

@ -24,6 +24,7 @@ Attributes:
GRACE: specifies the number of epochs that the training loss can stagnate or worsen GRACE: specifies the number of epochs that the training loss can stagnate or worsen
before the training is stopped early before the training is stopped early
TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher
LOG_FREQUENCY: number of steps that must pass before logging happens
Functions: Functions:
prepare_training_data(...): prepares the mnist training data prepare_training_data(...): prepares the mnist training data
@ -53,6 +54,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
GRACE: int = 10 GRACE: int = 10
TOTAL_LOSS_GRACE_CAP: int = 6 TOTAL_LOSS_GRACE_CAP: int = 6
LOG_FREQUENCY: int = 10
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int, def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
@ -236,7 +238,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None: def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
resultsample = decoder(sample).cpu() resultsample = decoder(sample).cpu()
grid = prepare_image(resultsample) grid = prepare_image(resultsample)
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0),
max_images=1, step=global_step_decoder) max_images=1, step=global_step_decoder)
_save_sample(**checkpointables) _save_sample(**checkpointables)
@ -308,6 +310,8 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
global_step_enc_dec: tf.Variable, global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]: epoch_var: tf.Variable) -> Dict[str, float]:
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
global_step=global_step_decoder):
epoch_var.assign(epoch) epoch_var.assign(epoch)
epoch_start_time = time.time() epoch_start_time = time.time()
# define loss variables # define loss variables
@ -323,7 +327,6 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
if verbose: if verbose:
print("learning rate change!") print("learning rate change!")
log_frequency = 10
batch_iteration = k.variable(0, dtype=tf.int64) batch_iteration = k.variable(0, dtype=tf.int64)
for x, _ in dataset: for x, _ in dataset:
# x discriminator # x discriminator
@ -371,18 +374,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
enc_dec_loss_avg(reconstruction_loss) enc_dec_loss_avg(reconstruction_loss)
encoder_loss_avg(encoder_loss) encoder_loss_avg(encoder_loss)
if int(global_step_decoder % log_frequency) == 0:
# log the losses every log frequency batches
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)
if int(batch_iteration) == 0: if int(batch_iteration) == 0:
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
grid = prepare_image(comparison.cpu(), nrow=64) grid = prepare_image(comparison.cpu(), nrow=64)
summary_ops_v2.image(name='reconstruction_' + str(epoch), summary_ops_v2.image(name='reconstruction',
tensor=k.expand_dims(grid, axis=0), max_images=1, tensor=k.expand_dims(grid, axis=0), max_images=1,
step=global_step_decoder) step=global_step_decoder)
@ -433,6 +428,12 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
_xd_train_loss = xd_real_loss + xd_fake_loss _xd_train_loss = xd_real_loss + xd_fake_loss
summary_ops_v2.scalar(name='x_discriminator_real_loss', tensor=xd_real_loss,
step=global_step)
summary_ops_v2.scalar(name='x_discriminator_fake_loss', tensor=xd_fake_loss,
step=global_step)
summary_ops_v2.scalar(name='x_discriminator_loss', tensor=_xd_train_loss,
step=global_step)
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables) xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables), optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
global_step=global_step) global_step=global_step)
@ -462,6 +463,8 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
xd_result = tf.squeeze(x_discriminator(x_fake)) xd_result = tf.squeeze(x_discriminator(x_fake))
_decoder_train_loss = binary_crossentropy(targets, xd_result) _decoder_train_loss = binary_crossentropy(targets, xd_result)
summary_ops_v2.scalar(name='decoder_loss', tensor=_decoder_train_loss,
step=global_step)
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables) grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
optimizer.apply_gradients(zip(grads, decoder.trainable_variables), optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
global_step=global_step) global_step=global_step)
@ -499,6 +502,12 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
_zd_train_loss = zd_real_loss + zd_fake_loss _zd_train_loss = zd_real_loss + zd_fake_loss
summary_ops_v2.scalar(name='z_discriminator_real_loss', tensor=zd_real_loss,
step=global_step)
summary_ops_v2.scalar(name='z_discriminator_fake_loss', tensor=zd_fake_loss,
step=global_step)
summary_ops_v2.scalar(name='z_discriminator_loss', tensor=_zd_train_loss,
step=global_step)
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables) zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables), optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
global_step=global_step) global_step=global_step)
@ -530,6 +539,12 @@ def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDi
reconstruction_loss = binary_crossentropy(inputs, x_decoded) reconstruction_loss = binary_crossentropy(inputs, x_decoded)
_enc_dec_train_loss = encoder_loss + reconstruction_loss _enc_dec_train_loss = encoder_loss + reconstruction_loss
summary_ops_v2.scalar(name='encoder_loss', tensor=encoder_loss,
step=global_step)
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
step=global_step)
summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss,
step=global_step)
enc_dec_grads = tape.gradient(_enc_dec_train_loss, enc_dec_grads = tape.gradient(_enc_dec_train_loss,
encoder.trainable_variables + decoder.trainable_variables) encoder.trainable_variables + decoder.trainable_variables)
optimizer.apply_gradients(zip(enc_dec_grads, optimizer.apply_gradients(zip(enc_dec_grads,
@ -570,6 +585,6 @@ if __name__ == "__main__":
total_classes=10) total_classes=10)
train_summary_writer = summary_ops_v2.create_file_writer( train_summary_writer = summary_ops_v2.create_file_writer(
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): with train_summary_writer.as_default():
train(dataset=train_dataset, iteration=iteration, train(dataset=train_dataset, iteration=iteration,
weights_prefix='weights/' + str(inlier_classes[0]) + '/') weights_prefix='weights/' + str(inlier_classes[0]) + '/')