From 6f36aa7fafc2363c1eca7affeaff1ce9680f379c Mon Sep 17 00:00:00 2001
From: Jim Martens
Date: Fri, 8 Feb 2019 15:45:27 +0100
Subject: [PATCH] Added more parameters to checkpoint

Signed-off-by: Jim Martens
---
 src/twomartens/masterthesis/aae/train.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index e67e7b3..50ff5ac 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -161,10 +161,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
         z_discriminator_optimizer=z_discriminator_optimizer,
         x_discriminator_optimizer=x_discriminator_optimizer,
         enc_dec_optimizer=enc_dec_optimizer,
-        step_counter=global_step_decoder)
+        global_step_decoder=global_step_decoder,
+        global_step_enc_dec=global_step_enc_dec,
+        global_step_xd=global_step_xd,
+        global_step_zd=global_step_zd,
+        learning_rate_var=learning_rate_var)
     if latest_checkpoint is not None:
         # if there is a checkpoint in the current training iteration, proceed from there
-        checkpoint.restore(latest_checkpoint).assert_consumed()
+        checkpoint.restore(latest_checkpoint)
 
     for epoch in range(train_epoch):
         # define loss variables
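For context, the patch replaces the single step_counter object tracked by the tf.train.Checkpoint with the individual global-step variables and the learning-rate variable, and it drops assert_consumed() on restore, presumably so that checkpoints written before the new objects were tracked can still be loaded without raising. The sketch below is a minimal illustration of the resulting pattern, not the code from train.py: the variable and optimizer definitions and the checkpoint directory are hypothetical placeholders, and the Adam optimizer type is an assumption.

    import tensorflow as tf

    # Hypothetical stand-ins for the objects referenced in the patch;
    # the real train() function defines its own versions of these.
    global_step_decoder = tf.Variable(0, dtype=tf.int64)
    global_step_enc_dec = tf.Variable(0, dtype=tf.int64)
    global_step_xd = tf.Variable(0, dtype=tf.int64)
    global_step_zd = tf.Variable(0, dtype=tf.int64)
    learning_rate_var = tf.Variable(1e-4, dtype=tf.float32)
    enc_dec_optimizer = tf.keras.optimizers.Adam(learning_rate_var)
    x_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_var)
    z_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_var)
    checkpoint_dir = "checkpoints"  # hypothetical path

    # After the patch, every step counter and the learning-rate variable
    # are tracked individually instead of a single generic step_counter.
    checkpoint = tf.train.Checkpoint(
        z_discriminator_optimizer=z_discriminator_optimizer,
        x_discriminator_optimizer=x_discriminator_optimizer,
        enc_dec_optimizer=enc_dec_optimizer,
        global_step_decoder=global_step_decoder,
        global_step_enc_dec=global_step_enc_dec,
        global_step_xd=global_step_xd,
        global_step_zd=global_step_zd,
        learning_rate_var=learning_rate_var,
    )

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is not None:
        # restore() without assert_consumed() tolerates checkpoints that
        # do not yet contain values for all newly tracked objects.
        checkpoint.restore(latest_checkpoint)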