Added saving of weights after each epoch

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 14:22:27 +01:00
parent ec9b64a757
commit 3c5f332880

View File

@ -249,6 +249,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
im = Image.fromarray(ndarr)
im.save(filename)
# save weights at end of epoch
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/epoch-' + str(epoch))
# check for improvements in error reduction - otherwise early stopping
strike = False
total_strike = False
@ -295,10 +305,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
print("Training stopped early!... save model weights")
# save trained models
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration))
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration))
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/final')
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
str(iteration) + '/final')
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,