From d9cd24f769092633edbcb09c2475669593117fde Mon Sep 17 00:00:00 2001
From: Jim Martens
Date: Fri, 8 Feb 2019 07:06:16 +0100
Subject: [PATCH] Added verbosity flag

Signed-off-by: Jim Martens
---
 src/twomartens/masterthesis/aae/train.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py
index 7fb5afd..55dcca5 100644
--- a/src/twomartens/masterthesis/aae/train.py
+++ b/src/twomartens/masterthesis/aae/train.py
@@ -39,7 +39,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
 
 def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int, channels: int = 1,
                 zsize: int = 32, lr: float = 0.002, batch_size: int = 128, train_epoch: int = 80,
-                folds: int = 5):
+                folds: int = 5, verbose: bool = True):
     """
     Train AAE for mnist data set.
     
@@ -52,6 +52,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     :param batch_size: size of each batch
     :param train_epoch: number of epochs to train
     :param folds: number of folds available
+    :param verbose: if True prints train progress info to console
     """
     # prepare data
     mnist_train = []
@@ -134,7 +135,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         # update learning rate
         if (epoch + 1) % 30 == 0:
             learning_rate_var.assign(learning_rate_var.value() / 4)
-            print("learning rate change!")
+            if verbose:
+                print("learning rate change!")
 
         nr_batches = len(mnist_train_x) // batch_size
         log_frequency = 10
@@ -209,15 +211,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
 
         epoch_end_time = time.time()
         per_epoch_time = epoch_end_time - epoch_start_time
-
-        print((
-            f"[{epoch + 1:d}/{train_epoch:d}] - "
-            f"train time: {per_epoch_time:.2f}, "
-            f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
-            f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
-            f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
-            f"Encoder loss: {encoder_loss_avg.result():.3f}"
-        ))
+
+        if verbose:
+            print((
+                f"[{epoch + 1:d}/{train_epoch:d}] - "
+                f"train time: {per_epoch_time:.2f}, "
+                f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
+                f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
+                f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
+                f"Encoder loss: {encoder_loss_avg.result():.3f}"
+            ))
 
         # save sample image
         resultsample = decoder(sample).cpu()
@@ -225,8 +228,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     os.makedirs(directory, exist_ok=True)
     save_image(resultsample, 'results' + str(inlier_classes[0]) +
                '/sample_' + str(epoch) + '.png')
-
-    print("Training finish!... save training results")
+    if verbose:
+        print("Training finish!... save training results")
+
+    # save trained models
     encoder.save_weights("./weights/encoder")
    decoder.save_weights("./weights/decoder")
     z_discriminator.save_weights("./weights/z_discriminator")
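
Usage note: a minimal sketch of how the new flag is meant to be called, assuming
the package is importable as twomartens.masterthesis.aae.train (suggested by the
src/ layout) and that the argument values below are placeholders rather than
values taken from this patch:

    from twomartens.masterthesis.aae.train import train_mnist

    # Default behaviour is unchanged: verbose=True keeps the console output.
    # Passing verbose=False suppresses the learning-rate change notice, the
    # per-epoch loss summary, and the final "Training finish!" message, all
    # of which are now guarded by the flag.
    train_mnist(folding_id=0, inlier_classes=[3], total_classes=10,
                verbose=False)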