diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index dc71eb2..aac4e03 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -194,7 +194,7 @@ def train(dataset: tf.data.Dataset, iteration: int, beta1=0.5, beta2=0.999), # global step counter 'epoch_var': k.variable(-1, dtype=tf.int64), - 'global_step': k.variable(0, dtype=tf.int64), + 'global_step': tf.train.get_or_create_global_step(), 'global_step_decoder': k.variable(0, dtype=tf.int64), 'global_step_enc_dec': k.variable(0, dtype=tf.int64), 'global_step_xd': k.variable(0, dtype=tf.int64), @@ -314,8 +314,7 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens global_step_enc_dec: tf.Variable, epoch_var: tf.Variable) -> Dict[str, float]: - with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY, - global_step=global_step): + with summary_ops_v2.record_summaries_every_n_global_steps(n=LOG_FREQUENCY): epoch_var.assign(epoch) epoch_start_time = time.time() # define loss variables @@ -331,7 +330,6 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens if verbose: print("learning rate change!") - batch_iteration = k.variable(0, dtype=tf.int64) for x, _ in dataset: # x discriminator _xd_train_loss = _train_xdiscriminator_step(x_discriminator=x_discriminator, @@ -378,14 +376,11 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens enc_dec_loss_avg(reconstruction_loss) encoder_loss_avg(encoder_loss) - if int(batch_iteration) == 0: - comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) - grid = prepare_image(comparison.cpu(), nrow=64) - summary_ops_v2.image(name='reconstruction', - tensor=k.expand_dims(grid, axis=0), max_images=1, - step=global_step_decoder) - - batch_iteration.assign_add(1) + comparison = k.concatenate([x[:64], x_decoded[:64]], axis=0) + grid = prepare_image(comparison.cpu(), nrow=64) + summary_ops_v2.image(name='reconstruction', + tensor=k.expand_dims(grid, axis=0), max_images=1, + step=global_step) global_step.assign_add(1) epoch_end_time = time.time()