diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index e67e7b3..50ff5ac 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -161,10 +161,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str, z_discriminator_optimizer=z_discriminator_optimizer, x_discriminator_optimizer=x_discriminator_optimizer, enc_dec_optimizer=enc_dec_optimizer, - step_counter=global_step_decoder) + global_step_decoder=global_step_decoder, + global_step_enc_dec=global_step_enc_dec, + global_step_xd=global_step_xd, + global_step_zd=global_step_zd, + learning_rate_var=learning_rate_var) if latest_checkpoint is not None: # if there is a checkpoint in the current training iteration, proceed from there - checkpoint.restore(latest_checkpoint).assert_consumed() + checkpoint.restore(latest_checkpoint) for epoch in range(train_epoch): # define loss variables