Added verbosity flag

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 07:06:16 +01:00
parent 1041f24b99
commit d9cd24f769

View File

@ -39,7 +39,7 @@ binary_crossentropy = tf.keras.losses.binary_crossentropy
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
channels: int = 1, zsize: int = 32, lr: float = 0.002,
batch_size: int = 128, train_epoch: int = 80,
folds: int = 5):
folds: int = 5, verbose: bool = True):
"""
Train AAE for mnist data set.
@ -52,6 +52,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
:param batch_size: size of each batch
:param train_epoch: number of epochs to train
:param folds: number of folds available
:param verbose: if True prints train progress info to console
"""
# prepare data
mnist_train = []
@ -134,7 +135,8 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
# update learning rate
if (epoch + 1) % 30 == 0:
learning_rate_var.assign(learning_rate_var.value() / 4)
print("learning rate change!")
if verbose:
print("learning rate change!")
nr_batches = len(mnist_train_x) // batch_size
log_frequency = 10
@ -209,15 +211,16 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time
print((
f"[{epoch + 1:d}/{train_epoch:d}] - "
f"train time: {per_epoch_time:.2f}, "
f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
f"Encoder loss: {encoder_loss_avg.result():.3f}"
))
if verbose:
print((
f"[{epoch + 1:d}/{train_epoch:d}] - "
f"train time: {per_epoch_time:.2f}, "
f"Decoder loss: {decoder_loss_avg.result()}, X Discriminator loss: {xd_loss_avg.result():.3f}, "
f"Z Discriminator loss: {zd_loss_avg.result():.3f}, "
f"Encoder + Decoder loss: {enc_dec_loss_avg.result():.3f}, "
f"Encoder loss: {encoder_loss_avg.result():.3f}"
))
# save sample image
resultsample = decoder(sample).cpu()
@ -225,8 +228,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
os.makedirs(directory, exist_ok=True)
save_image(resultsample,
'results' + str(inlier_classes[0]) + '/sample_' + str(epoch) + '.png')
print("Training finish!... save training results")
if verbose:
print("Training finish!... save training results")
# save trained models
encoder.save_weights("./weights/encoder")
decoder.save_weights("./weights/decoder")
z_discriminator.save_weights("./weights/z_discriminator")