Added more parameters to checkpoint

Signed-off-by: Jim Martens <github@2martens.de>
2019-02-08 15:45:27 +01:00
parent 658973d90d
commit 6f36aa7faf


@@ -161,10 +161,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
                           z_discriminator_optimizer=z_discriminator_optimizer,
                           x_discriminator_optimizer=x_discriminator_optimizer,
                           enc_dec_optimizer=enc_dec_optimizer,
-                          step_counter=global_step_decoder)
+                          global_step_decoder=global_step_decoder,
+                          global_step_enc_dec=global_step_enc_dec,
+                          global_step_xd=global_step_xd,
+                          global_step_zd=global_step_zd,
+                          learning_rate_var=learning_rate_var)
     if latest_checkpoint is not None:
         # if there is a checkpoint in the current training iteration, proceed from there
-        checkpoint.restore(latest_checkpoint).assert_consumed()
+        checkpoint.restore(latest_checkpoint)
     for epoch in range(train_epoch):
         # define loss variables
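
For context, tf.train.Checkpoint accepts arbitrary keyword arguments; each one becomes a named, trackable dependency that is saved and restored under that key, which is how the per-network step counters and the learning rate variable get added here. Dropping .assert_consumed() is presumably what keeps older checkpoint files loadable: assert_consumed() raises when any tracked object has no matching saved value (as happens when new keys are added to the Checkpoint), whereas a bare restore() matches what it can. A minimal sketch of the pattern, using the names from the diff; the optimizer types, initial values, and checkpoint directory are assumptions, and the model objects tracked in the real code are omitted:

import tensorflow as tf

# One step counter per sub-network, mirroring the names in the diff.
global_step_decoder = tf.Variable(0, dtype=tf.int64, trainable=False)
global_step_enc_dec = tf.Variable(0, dtype=tf.int64, trainable=False)
global_step_xd = tf.Variable(0, dtype=tf.int64, trainable=False)
global_step_zd = tf.Variable(0, dtype=tf.int64, trainable=False)
learning_rate_var = tf.Variable(1e-4, trainable=False)  # assumed initial value

# Optimizer type is an assumption; the diff only shows the variable names.
z_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_var)
x_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_var)
enc_dec_optimizer = tf.keras.optimizers.Adam(learning_rate_var)

# Every keyword argument becomes a named checkpoint dependency.
checkpoint = tf.train.Checkpoint(
    z_discriminator_optimizer=z_discriminator_optimizer,
    x_discriminator_optimizer=x_discriminator_optimizer,
    enc_dec_optimizer=enc_dec_optimizer,
    global_step_decoder=global_step_decoder,
    global_step_enc_dec=global_step_enc_dec,
    global_step_xd=global_step_xd,
    global_step_zd=global_step_zd,
    learning_rate_var=learning_rate_var)

# Resume from the latest checkpoint if one exists ("checkpoints/" is assumed).
latest_checkpoint = tf.train.latest_checkpoint("checkpoints/")
if latest_checkpoint is not None:
    # Without .assert_consumed(), restoring an older file that lacks the
    # newly added keys proceeds silently instead of raising.
    checkpoint.restore(latest_checkpoint)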