From a0094266b378c920fbe37e21e1c04f5de2ca0361 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 13:36:09 +0100 Subject: [PATCH] Added training iteration variable to train function Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index bdca3b3..998c10a 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -37,6 +37,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int, + iteration: int, channels: int = 1, zsize: int = 32, lr: float = 0.002, batch_size: int = 128, train_epoch: int = 80, folds: int = 5, verbose: bool = True): @@ -46,6 +47,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i :param folding_id: id of fold used for test data :param inlier_classes: list of class ids that are considered inliers :param total_classes: total number of classes + :param iteration: identifier for the current training run :param channels: number of channels in input image :param zsize: size of the intermediary z :param lr: learning rate @@ -291,10 +293,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i print("Training stopped early!... save model weights") # save trained models - encoder.save_weights("./weights/encoder/" + str(inlier_classes[0])) - decoder.save_weights("./weights/decoder/" + str(inlier_classes[0])) - z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0])) - x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0])) + encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]) + '/' + str(iteration)) + decoder.save_weights("./weights/decoder/" + str(inlier_classes[0]) + '/' + str(iteration)) + z_discriminator.save_weights("./weights/z_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) + x_discriminator.save_weights("./weights/x_discriminator/" + str(inlier_classes[0]) + '/' + str(iteration)) def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, @@ -475,4 +477,4 @@ if __name__ == "__main__": train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries(): - train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10) + train_mnist(folding_id=0, inlier_classes=inlier_classes, total_classes=10, iteration=iteration)