Added more parameters to checkpoint
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -161,10 +161,14 @@ def train(dataset: tf.data.Dataset, iteration: int, result_prefix: str,
|
|||||||
z_discriminator_optimizer=z_discriminator_optimizer,
|
z_discriminator_optimizer=z_discriminator_optimizer,
|
||||||
x_discriminator_optimizer=x_discriminator_optimizer,
|
x_discriminator_optimizer=x_discriminator_optimizer,
|
||||||
enc_dec_optimizer=enc_dec_optimizer,
|
enc_dec_optimizer=enc_dec_optimizer,
|
||||||
step_counter=global_step_decoder)
|
global_step_decoder=global_step_decoder,
|
||||||
|
global_step_enc_dec=global_step_enc_dec,
|
||||||
|
global_step_xd=global_step_xd,
|
||||||
|
global_step_zd=global_step_zd,
|
||||||
|
learning_rate_var=learning_rate_var)
|
||||||
if latest_checkpoint is not None:
|
if latest_checkpoint is not None:
|
||||||
# if there is a checkpoint in the current training iteration, proceed from there
|
# if there is a checkpoint in the current training iteration, proceed from there
|
||||||
checkpoint.restore(latest_checkpoint).assert_consumed()
|
checkpoint.restore(latest_checkpoint)
|
||||||
|
|
||||||
for epoch in range(train_epoch):
|
for epoch in range(train_epoch):
|
||||||
# define loss variables
|
# define loss variables
|
||||||
|
|||||||
Reference in New Issue
Block a user