From 720ff63c7da90567cfa48bfae348680ced4f046e Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 14:53:45 +0100 Subject: [PATCH] Updated saving of checkpoints Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 31 +++++++++++++----------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 0a78e8b..b288d10 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -127,6 +127,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i total_lowest_loss = math.inf grace_period = GRACE + checkpoint_prefix = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/ckpt' + latest_checkpoint = tf.train.latest_checkpoint(checkpoint_prefix) + checkpoint = tf.train.Checkpoint(encoder=encoder, + decoder=decoder, + z_discriminator=z_discriminator, + x_discriminator=x_discriminator, + decoder_optimizer=decoder_optimizer, + z_discriminator_optimizer=z_discriminator_optimizer, + x_discriminator_optimizer=x_discriminator_optimizer, + enc_dec_optimizer=enc_dec_optimizer, + step_counter=global_step_decoder) + if latest_checkpoint is not None: + # if there is a checkpoint in the current training iteration, proceed from there + checkpoint.restore(latest_checkpoint).assert_consumed() + for epoch in range(train_epoch): # define loss variables encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) @@ -250,14 +265,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i im.save(filename) # save weights at end of epoch - encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/epoch-' + str(epoch)) - decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/epoch-' + str(epoch)) - z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/epoch-' + str(epoch)) - x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/epoch-' + str(epoch)) + checkpoint.save(checkpoint_prefix) # check for improvements in error reduction - otherwise early stopping strike = False @@ -305,12 +313,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i print("Training stopped early!... save model weights") # save trained models - encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final') - decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final') - z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/final') - x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + - str(iteration) + '/final') + checkpoint.save(checkpoint_prefix) def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,