diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index d74405d..0a78e8b 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -248,6 +248,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i ndarr = grid.cpu().numpy() im = Image.fromarray(ndarr) im.save(filename) + + # save weights at end of epoch + encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/epoch-' + str(epoch)) + decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/epoch-' + str(epoch)) + z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/epoch-' + str(epoch)) + x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/epoch-' + str(epoch)) # check for improvements in error reduction - otherwise early stopping strike = False @@ -295,10 +305,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i print("Training stopped early!... save model weights") # save trained models - encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration)) - decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration)) - z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) - x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) + encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final') + decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration) + '/final') + z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/final') + x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + + str(iteration) + '/final') def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,