Added saving of weights after each epoch
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -249,6 +249,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filename)
|
||||
|
||||
# save weights at end of epoch
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/epoch-' + str(epoch))
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
strike = False
|
||||
total_strike = False
|
||||
@ -295,10 +305,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
print("Training stopped early!... save model weights")
|
||||
|
||||
# save trained models
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/final')
|
||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||
str(iteration) + '/final')
|
||||
|
||||
|
||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
|
||||
Reference in New Issue
Block a user