diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 0d0f244..504bf12 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -220,10 +220,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i print("Training finish!... save training results") # save trained models - encoder.save_weights("./weights/encoder/") - decoder.save_weights("./weights/decoder/") - z_discriminator.save_weights("./weights/z_discriminator/") - x_discriminator.save_weights("./weights/x_discriminator/") + encoder.save_weights("./weights/encoder/" + str(inlier_classes[0])) + decoder.save_weights("./weights/decoder/" + str(inlier_classes[0])) + z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0])) + x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0])) def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,