Added saving of weights after each epoch
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -248,6 +248,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
ndarr = grid.cpu().numpy()
|
ndarr = grid.cpu().numpy()
|
||||||
im = Image.fromarray(ndarr)
|
im = Image.fromarray(ndarr)
|
||||||
im.save(filename)
|
im.save(filename)
|
||||||
|
|
||||||
|
# save weights at end of epoch
|
||||||
|
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' +
|
||||||
|
str(iteration) + '/epoch-' + str(epoch))
|
||||||
|
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' +
|
||||||
|
str(iteration) + '/epoch-' + str(epoch))
|
||||||
|
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||||
|
str(iteration) + '/epoch-' + str(epoch))
|
||||||
|
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||||
|
str(iteration) + '/epoch-' + str(epoch))
|
||||||
|
|
||||||
# check for improvements in error reduction - otherwise early stopping
|
# check for improvements in error reduction - otherwise early stopping
|
||||||
strike = False
|
strike = False
|
||||||
@ -295,10 +305,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
print("Training stopped early!... save model weights")
|
print("Training stopped early!... save model weights")
|
||||||
|
|
||||||
# save trained models
|
# save trained models
|
||||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||||
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration))
|
decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final')
|
||||||
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||||
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration))
|
str(iteration) + '/final')
|
||||||
|
x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' +
|
||||||
|
str(iteration) + '/final')
|
||||||
|
|
||||||
|
|
||||||
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||||
|
|||||||
Reference in New Issue
Block a user