Updated saving of checkpoints
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -127,6 +127,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
total_lowest_loss = math.inf
|
||||
grace_period = GRACE
|
||||
|
||||
checkpoint_prefix = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/ckpt'
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_prefix)
|
||||
checkpoint = tf.train.Checkpoint(encoder=encoder,
|
||||
decoder=decoder,
|
||||
z_discriminator=z_discriminator,
|
||||
x_discriminator=x_discriminator,
|
||||
decoder_optimizer=decoder_optimizer,
|
||||
z_discriminator_optimizer=z_discriminator_optimizer,
|
||||
x_discriminator_optimizer=x_discriminator_optimizer,
|
||||
enc_dec_optimizer=enc_dec_optimizer,
|
||||
step_counter=global_step_decoder)
|
||||
if latest_checkpoint is not None:
|
||||
# if there is a checkpoint in the current training iteration, proceed from there
|
||||
checkpoint.restore(latest_checkpoint).assert_consumed()
|
||||
|
||||
for epoch in range(train_epoch):
|
||||
# define loss variables
|
||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||
@ -250,14 +265,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
im.save(filename)
|
||||
|
||||
# save weights at end of epoch
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
strike = False
|
||||
@ -305,12 +313,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
print("Training stopped early!... save model weights")
|
||||
|
||||
# save trained models
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/final')
|
||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/final')
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
|
||||
|
||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
|
||||
Reference in New Issue
Block a user