diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 22490bf..dc71eb2 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -194,6 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int, beta1=0.5, beta2=0.999), # global step counter 'epoch_var': k.variable(-1, dtype=tf.int64), + 'global_step': k.variable(0, dtype=tf.int64), 'global_step_decoder': k.variable(0, dtype=tf.int64), 'global_step_enc_dec': k.variable(0, dtype=tf.int64), 'global_step_xd': k.variable(0, dtype=tf.int64), @@ -235,12 +236,14 @@ def train(dataset: tf.data.Dataset, iteration: int, )) # save sample image summary - def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None: + def _save_sample(decoder: Decoder, global_step: tf.Variable, **kwargs) -> None: resultsample = decoder(sample).cpu() grid = prepare_image(resultsample) summary_ops_v2.image(name='sample', tensor=k.expand_dims(grid, axis=0), - max_images=1, step=global_step_decoder) - _save_sample(**checkpointables) + max_images=1, step=global_step) + + with summary_ops_v2.always_record_summaries(): + _save_sample(**checkpointables) # save weights at end of epoch checkpoint.save(checkpoint_prefix) @@ -305,13 +308,14 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens x_discriminator_optimizer: tf.train.Optimizer, z_discriminator_optimizer: tf.train.Optimizer, enc_dec_optimizer: tf.train.Optimizer, + global_step: tf.Variable, global_step_xd: tf.Variable, global_step_zd: tf.Variable, global_step_decoder: tf.Variable, global_step_enc_dec: tf.Variable, epoch_var: tf.Variable) -> Dict[str, float]: with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY, - global_step=global_step_decoder): + global_step=global_step): epoch_var.assign(epoch) epoch_start_time = time.time() # define loss variables @@ -382,6 +386,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens step=global_step_decoder) batch_iteration.assign_add(1) + global_step.assign_add(1) epoch_end_time = time.time() per_epoch_time = epoch_end_time - epoch_start_time