Improved summary logging
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -24,6 +24,7 @@ Attributes:
|
|||||||
GRACE: specifies the number of epochs that the training loss can stagnate or worsen
|
GRACE: specifies the number of epochs that the training loss can stagnate or worsen
|
||||||
before the training is stopped early
|
before the training is stopped early
|
||||||
TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher
|
TOTAL_LOSS_GRACE_CAP: upper limit for total loss, grace countdown only enabled if total loss higher
|
||||||
|
LOG_FREQUENCY: number of steps that must pass before logging happens
|
||||||
|
|
||||||
Functions:
|
Functions:
|
||||||
prepare_training_data(...): prepares the mnist training data
|
prepare_training_data(...): prepares the mnist training data
|
||||||
@ -53,6 +54,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
|
|||||||
|
|
||||||
GRACE: int = 10
|
GRACE: int = 10
|
||||||
TOTAL_LOSS_GRACE_CAP: int = 6
|
TOTAL_LOSS_GRACE_CAP: int = 6
|
||||||
|
LOG_FREQUENCY: int = 10
|
||||||
|
|
||||||
|
|
||||||
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def prepare_training_data(test_fold_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||||
@ -236,7 +238,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
|
|||||||
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
|
def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None:
|
||||||
resultsample = decoder(sample).cpu()
|
resultsample = decoder(sample).cpu()
|
||||||
grid = prepare_image(resultsample)
|
grid = prepare_image(resultsample)
|
||||||
summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0),
|
summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0),
|
||||||
max_images=1, step=global_step_decoder)
|
max_images=1, step=global_step_decoder)
|
||||||
_save_sample(**checkpointables)
|
_save_sample(**checkpointables)
|
||||||
|
|
||||||
@ -308,100 +310,93 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
|
|||||||
global_step_enc_dec: tf.Variable,
|
global_step_enc_dec: tf.Variable,
|
||||||
epoch_var: tf.Variable) -> Dict[str, float]:
|
epoch_var: tf.Variable) -> Dict[str, float]:
|
||||||
|
|
||||||
epoch_var.assign(epoch)
|
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
|
||||||
epoch_start_time = time.time()
|
global_step=global_step_decoder):
|
||||||
# define loss variables
|
epoch_var.assign(epoch)
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
epoch_start_time = time.time()
|
||||||
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
# define loss variables
|
||||||
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||||
zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
|
decoder_loss_avg = tfe.metrics.Mean(name='decoder_loss', dtype=tf.float32)
|
||||||
xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)
|
enc_dec_loss_avg = tfe.metrics.Mean(name='encoder_decoder_loss', dtype=tf.float32)
|
||||||
|
zd_loss_avg = tfe.metrics.Mean(name='z_discriminator_loss', dtype=tf.float32)
|
||||||
# update learning rate
|
xd_loss_avg = tfe.metrics.Mean(name='x_discriminator_loss', dtype=tf.float32)
|
||||||
if (epoch + 1) % 30 == 0:
|
|
||||||
learning_rate_var.assign(learning_rate_var.value() / 4)
|
|
||||||
if verbose:
|
|
||||||
print("learning rate change!")
|
|
||||||
|
|
||||||
log_frequency = 10
|
|
||||||
batch_iteration = k.variable(0, dtype=tf.int64)
|
|
||||||
for x, _ in dataset:
|
|
||||||
# x discriminator
|
|
||||||
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
|
||||||
decoder=decoder,
|
|
||||||
optimizer=x_discriminator_optimizer,
|
|
||||||
inputs=x,
|
|
||||||
targets_real=targets_real,
|
|
||||||
targets_fake=targets_fake,
|
|
||||||
global_step=global_step_xd,
|
|
||||||
z_generator=z_generator)
|
|
||||||
xd_loss_avg(_xd_train_loss)
|
|
||||||
|
|
||||||
# --------
|
# update learning rate
|
||||||
# decoder
|
if (epoch + 1) % 30 == 0:
|
||||||
_decoder_train_loss = _train_decoder_step(decoder=decoder,
|
learning_rate_var.assign(learning_rate_var.value() / 4)
|
||||||
x_discriminator=x_discriminator,
|
if verbose:
|
||||||
optimizer=decoder_optimizer,
|
print("learning rate change!")
|
||||||
targets=targets_real,
|
|
||||||
global_step=global_step_decoder,
|
|
||||||
z_generator=z_generator)
|
|
||||||
decoder_loss_avg(_decoder_train_loss)
|
|
||||||
|
|
||||||
# ---------
|
batch_iteration = k.variable(0, dtype=tf.int64)
|
||||||
# z discriminator
|
for x, _ in dataset:
|
||||||
_zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator,
|
# x discriminator
|
||||||
encoder=encoder,
|
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
|
||||||
optimizer=z_discriminator_optimizer,
|
decoder=decoder,
|
||||||
inputs=x,
|
optimizer=x_discriminator_optimizer,
|
||||||
targets_real=targets_real,
|
inputs=x,
|
||||||
targets_fake=targets_fake,
|
targets_real=targets_real,
|
||||||
global_step=global_step_zd,
|
targets_fake=targets_fake,
|
||||||
z_generator=z_generator)
|
global_step=global_step_xd,
|
||||||
zd_loss_avg(_zd_train_loss)
|
z_generator=z_generator)
|
||||||
|
xd_loss_avg(_xd_train_loss)
|
||||||
# -----------
|
|
||||||
# encoder + decoder
|
# --------
|
||||||
encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder,
|
# decoder
|
||||||
decoder=decoder,
|
_decoder_train_loss = _train_decoder_step(decoder=decoder,
|
||||||
z_discriminator=z_discriminator,
|
x_discriminator=x_discriminator,
|
||||||
optimizer=enc_dec_optimizer,
|
optimizer=decoder_optimizer,
|
||||||
inputs=x,
|
targets=targets_real,
|
||||||
targets=targets_real,
|
global_step=global_step_decoder,
|
||||||
global_step=global_step_enc_dec)
|
z_generator=z_generator)
|
||||||
enc_dec_loss_avg(reconstruction_loss)
|
decoder_loss_avg(_decoder_train_loss)
|
||||||
encoder_loss_avg(encoder_loss)
|
|
||||||
|
# ---------
|
||||||
if int(global_step_decoder % log_frequency) == 0:
|
# z discriminator
|
||||||
# log the losses every log frequency batches
|
_zd_train_loss = _train_zdiscriminator_step(z_discriminator=z_discriminator,
|
||||||
summary_ops_v2.scalar('encoder_loss', encoder_loss_avg.result(False), step=global_step_enc_dec)
|
encoder=encoder,
|
||||||
summary_ops_v2.scalar('decoder_loss', decoder_loss_avg.result(False), step=global_step_decoder)
|
optimizer=z_discriminator_optimizer,
|
||||||
summary_ops_v2.scalar('encoder_decoder_loss', enc_dec_loss_avg.result(False), step=global_step_enc_dec)
|
inputs=x,
|
||||||
summary_ops_v2.scalar('z_discriminator_loss', zd_loss_avg.result(False), step=global_step_zd)
|
targets_real=targets_real,
|
||||||
summary_ops_v2.scalar('x_discriminator_loss', xd_loss_avg.result(False), step=global_step_xd)
|
targets_fake=targets_fake,
|
||||||
|
global_step=global_step_zd,
|
||||||
if int(batch_iteration) == 0:
|
z_generator=z_generator)
|
||||||
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
zd_loss_avg(_zd_train_loss)
|
||||||
grid = prepare_image(comparison.cpu(), nrow=64)
|
|
||||||
summary_ops_v2.image(name='reconstruction_' + str(epoch),
|
# -----------
|
||||||
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
# encoder + decoder
|
||||||
step=global_step_decoder)
|
encoder_loss, reconstruction_loss, x_decoded = _train_enc_dec_step(encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
batch_iteration.assign_add(1)
|
z_discriminator=z_discriminator,
|
||||||
|
optimizer=enc_dec_optimizer,
|
||||||
|
inputs=x,
|
||||||
|
targets=targets_real,
|
||||||
|
global_step=global_step_enc_dec)
|
||||||
|
enc_dec_loss_avg(reconstruction_loss)
|
||||||
|
encoder_loss_avg(encoder_loss)
|
||||||
|
|
||||||
epoch_end_time = time.time()
|
if int(batch_iteration) == 0:
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
|
||||||
|
grid = prepare_image(comparison.cpu(), nrow=64)
|
||||||
# final losses of epoch
|
summary_ops_v2.image(name='reconstruction',
|
||||||
outputs = {
|
tensor=k.expand_dims(grid, axis=0), max_images=1,
|
||||||
'decoder_loss': decoder_loss_avg.result(False),
|
step=global_step_decoder)
|
||||||
'encoder_loss': encoder_loss_avg.result(False),
|
|
||||||
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
batch_iteration.assign_add(1)
|
||||||
'xd_loss': xd_loss_avg.result(False),
|
|
||||||
'zd_loss': zd_loss_avg.result(False),
|
epoch_end_time = time.time()
|
||||||
'per_epoch_time': per_epoch_time,
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
}
|
|
||||||
|
# final losses of epoch
|
||||||
return outputs
|
outputs = {
|
||||||
|
'decoder_loss': decoder_loss_avg.result(False),
|
||||||
|
'encoder_loss': encoder_loss_avg.result(False),
|
||||||
|
'enc_dec_loss': enc_dec_loss_avg.result(False),
|
||||||
|
'xd_loss': xd_loss_avg.result(False),
|
||||||
|
'zd_loss': zd_loss_avg.result(False),
|
||||||
|
'per_epoch_time': per_epoch_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||||
@ -432,7 +427,13 @@ def _train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder
|
|||||||
xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
|
xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
|
||||||
|
|
||||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||||
|
|
||||||
|
summary_ops_v2.scalar(name='x_discriminator_real_loss', tensor=xd_real_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='x_discriminator_fake_loss', tensor=xd_fake_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='x_discriminator_loss', tensor=_xd_train_loss,
|
||||||
|
step=global_step)
|
||||||
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
xd_grads = tape.gradient(_xd_train_loss, x_discriminator.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
optimizer.apply_gradients(zip(xd_grads, x_discriminator.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step)
|
||||||
@ -461,7 +462,9 @@ def _train_decoder_step(decoder: Decoder, x_discriminator: XDiscriminator,
|
|||||||
x_fake = decoder(z)
|
x_fake = decoder(z)
|
||||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||||
_decoder_train_loss = binary_crossentropy(targets, xd_result)
|
_decoder_train_loss = binary_crossentropy(targets, xd_result)
|
||||||
|
|
||||||
|
summary_ops_v2.scalar(name='decoder_loss', tensor=_decoder_train_loss,
|
||||||
|
step=global_step)
|
||||||
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
grads = tape.gradient(_decoder_train_loss, decoder.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
|
optimizer.apply_gradients(zip(grads, decoder.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step)
|
||||||
@ -498,7 +501,13 @@ def _train_zdiscriminator_step(z_discriminator: ZDiscriminator, encoder: Encoder
|
|||||||
zd_fake_loss = binary_crossentropy(targets_fake, zd_result)
|
zd_fake_loss = binary_crossentropy(targets_fake, zd_result)
|
||||||
|
|
||||||
_zd_train_loss = zd_real_loss + zd_fake_loss
|
_zd_train_loss = zd_real_loss + zd_fake_loss
|
||||||
|
|
||||||
|
summary_ops_v2.scalar(name='z_discriminator_real_loss', tensor=zd_real_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='z_discriminator_fake_loss', tensor=zd_fake_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='z_discriminator_loss', tensor=_zd_train_loss,
|
||||||
|
step=global_step)
|
||||||
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
zd_grads = tape.gradient(_zd_train_loss, z_discriminator.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
optimizer.apply_gradients(zip(zd_grads, z_discriminator.trainable_variables),
|
||||||
global_step=global_step)
|
global_step=global_step)
|
||||||
@ -529,7 +538,13 @@ def _train_enc_dec_step(encoder: Encoder, decoder: Decoder, z_discriminator: ZDi
|
|||||||
encoder_loss = binary_crossentropy(targets, zd_result) * 2.0
|
encoder_loss = binary_crossentropy(targets, zd_result) * 2.0
|
||||||
reconstruction_loss = binary_crossentropy(inputs, x_decoded)
|
reconstruction_loss = binary_crossentropy(inputs, x_decoded)
|
||||||
_enc_dec_train_loss = encoder_loss + reconstruction_loss
|
_enc_dec_train_loss = encoder_loss + reconstruction_loss
|
||||||
|
|
||||||
|
summary_ops_v2.scalar(name='encoder_loss', tensor=encoder_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='reconstruction_loss', tensor=reconstruction_loss,
|
||||||
|
step=global_step)
|
||||||
|
summary_ops_v2.scalar(name='encoder_decoder_loss', tensor=_enc_dec_train_loss,
|
||||||
|
step=global_step)
|
||||||
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
enc_dec_grads = tape.gradient(_enc_dec_train_loss,
|
||||||
encoder.trainable_variables + decoder.trainable_variables)
|
encoder.trainable_variables + decoder.trainable_variables)
|
||||||
optimizer.apply_gradients(zip(enc_dec_grads,
|
optimizer.apply_gradients(zip(enc_dec_grads,
|
||||||
@ -570,6 +585,6 @@ if __name__ == "__main__":
|
|||||||
total_classes=10)
|
total_classes=10)
|
||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default():
|
||||||
train(dataset=train_dataset, iteration=iteration,
|
train(dataset=train_dataset, iteration=iteration,
|
||||||
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
weights_prefix='weights/' + str(inlier_classes[0]) + '/')
|
||||||
|
|||||||
Reference in New Issue
Block a user