diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 6418eb2..70a4154 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -190,6 +190,7 @@ def train(dataset: tf.data.Dataset, iteration: int, 'x_discriminator_optimizer': AdamOptimizer(learning_rate=checkpointables['learning_rate_var'], beta1=0.5, beta2=0.999), # global step counter + 'epoch_var': k.variable(-1, dtype=tf.int64), 'global_step_decoder': k.variable(0, dtype=tf.int64), 'global_step_enc_dec': k.variable(0, dtype=tf.int64), 'global_step_xd': k.variable(0, dtype=tf.int64), @@ -204,15 +205,24 @@ def train(dataset: tf.data.Dataset, iteration: int, checkpoint = tf.train.Checkpoint(**checkpointables) checkpoint.restore(latest_checkpoint) - for epoch in range(train_epoch): - outputs = _train_one_epoch(epoch, dataset, targets_real=y_real, - targets_fake=y_fake, z_generator= z_generator, + def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int: + return int(epoch_var) + + last_epoch = _get_last_epoch(**checkpointables) + previous_epochs = 0 + if last_epoch != -1: + previous_epochs = last_epoch + 1 + + for epoch in range(train_epoch - previous_epochs): + _epoch = epoch + previous_epochs + outputs = _train_one_epoch(_epoch, dataset, targets_real=y_real, + targets_fake=y_fake, z_generator=z_generator, verbose=verbose, **checkpointables) if verbose: print(( - f"[{epoch + 1:d}/{train_epoch:d}] - " + f"[{_epoch + 1:d}/{train_epoch:d}] - " f"train time: {outputs['per_epoch_time']:.2f}, " f"Decoder loss: {outputs['decoder_loss']:.3f}, " f"X Discriminator loss: {outputs['xd_loss']:.3f}, " @@ -222,7 +232,7 @@ def train(dataset: tf.data.Dataset, iteration: int, )) # save sample image summary - def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable) -> None: + def _save_sample(decoder: Decoder, global_step_decoder: tf.Variable, **kwargs) -> None: resultsample = decoder(sample).cpu() grid = prepare_image(resultsample) summary_ops_v2.image(name='sample_' + str(epoch), tensor=k.expand_dims(grid, axis=0), @@ -293,8 +303,10 @@ def _train_one_epoch(epoch: int, dataset: tf.data.Dataset, targets_real: tf.Tens enc_dec_optimizer: tf.train.Optimizer, global_step_xd: tf.Variable, global_step_zd: tf.Variable, global_step_decoder: tf.Variable, - global_step_enc_dec: tf.Variable) -> Dict[str, float]: + global_step_enc_dec: tf.Variable, + epoch_var: tf.Variable) -> Dict[str, float]: + epoch_var.assign(epoch) epoch_start_time = time.time() # define loss variables encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)