Added verbosity flag

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 07:06:16 +01:00
parent 1041f24b99
commit d9cd24f769

View File

@@ -39,7 +39,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
 def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
                 channels: int = 1, zsize: int = 32, lr: float = 0.002,
                 batch_size: int = 128, train_epoch: int = 80,
-                folds: int = 5):
+                folds: int = 5, verbose: bool = True):
     """
     Train AAE for mnist data set.
@@ -52,6 +52,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     :param batch_size: size of each batch
     :param train_epoch: number of epochs to train
     :param folds: number of folds available
+    :param verbose: if True prints train progress info to console
     """
     # prepare data
     mnist_train = []
@@ -134,7 +135,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         # update learning rate
         if (epoch + 1) % 30 == 0:
             learning_rate_var.assign(learning_rate_var.value() / 4)
-            print("learning rate change!")
+            if verbose:
+                print("learning rate change!")

     nr_batches = len(mnist_train_x) // batch_size
     log_frequency = 10
@@ -209,15 +211,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         epoch_end_time = time.time()
         per_epoch_time = epoch_end_time - epoch_start_time

-        print((
-            f"[{epoch + 1:d}/{train_epoch:d}] - "
-            f"train time: {per_epoch_time:.2f}, "
-            f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
-            f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
-            f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
-            f"Encoder loss: {encoder_loss_avg.result():.3f}"
-        ))
+        if verbose:
+            print((
+                f"[{epoch + 1:d}/{train_epoch:d}] - "
+                f"train time: {per_epoch_time:.2f}, "
+                f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
+                f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
+                f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
+                f"Encoder loss: {encoder_loss_avg.result():.3f}"
+            ))

         # save sample image
         resultsample = decoder(sample).cpu()
@@ -225,8 +228,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
             os.makedirs(directory, exist_ok=True)
             save_image(resultsample,
                        'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')

-    print("Training finish!... save training results")
+    if verbose:
+        print("Training finish!... save training results")
+    # save trained models
     encoder.save_weights("./weights/encoder")
     decoder.save_weights("./weights/decoder")
     z_discriminator.save_weights("./weights/z_discriminator")