Added saving of weights after each epoch

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 14:22:27 +01:00
parent ec9b64a757
commit 3c5f332880

View File

@ -248,6 +248,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
ndarr = grid.cpu().numpy() ndarr = grid.cpu().numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(filename) im.save(filename)
# save weights at end of epoch
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
# check for improvements in error reduction - otherwise early stopping # check for improvements in error reduction - otherwise early stopping
strike = False strike = False
@ -295,10 +305,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
print("Training stopped early!... save model weights") print("Training stopped early!... save model weights")
# save trained models # save trained models
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration)) encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration)) decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) str(iteration) + '/final')
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/final')
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,