Improved summary logs

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 19:27:33 +01:00
parent a2aa794d47
commit a8f8015ad8

View File

@ -194,7 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int,
beta1=0.5, beta2=0.999),
# global step counter
'epoch_var': k.variable(-1, dtype=tf.int64),
'global_step': k.variable(0, dtype=tf.int64),
'global_step': tf.train.get_or_create_global_step(),
'global_step_decoder': k.variable(0, dtype=tf.int64),
'global_step_enc_dec': k.variable(0, dtype=tf.int64),
'global_step_xd': k.variable(0, dtype=tf.int64),
@ -314,8 +314,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
global_step_enc_dec: tf.Variable,
epoch_var: tf.Variable) -> Dict[str, float]:
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY,
global_step=global_step):
with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY):
epoch_var.assign(epoch)
epoch_start_time = time.time()
# define loss variables
@ -331,7 +330,6 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
if verbose:
print("learning rate change!")
batch_iteration = k.variable(0, dtype=tf.int64)
for x, _ in dataset:
# x discriminator
_xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator,
@ -378,14 +376,11 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens
enc_dec_loss_avg(reconstruction_loss)
encoder_loss_avg(encoder_loss)
if int(batch_iteration) == 0:
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
grid = prepare_image(comparison.cpu(), nrow=64)
summary_ops_v2.image(name='reconstruction',
tensor=k.expand_dims(grid, axis=0), max_images=1,
step=global_step_decoder)
batch_iteration.assign_add(1)
comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0)
grid = prepare_image(comparison.cpu(), nrow=64)
summary_ops_v2.image(name='reconstruction',
tensor=k.expand_dims(grid, axis=0), max_images=1,
step=global_step)
global_step.assign_add(1)
epoch_end_time = time.time()