From c19ff19820f305c5437bdaf4c50e4c1157ca7d9c Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 13:20:17 +0100 Subject: [PATCH] Implemented early stopping Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 72 ++++++++++++++++++++---- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index cdddfdf..5bdc2f8 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -16,6 +16,7 @@ """aae.train.py: contains training functionality""" import functools +import math import os import pickle import time @@ -114,6 +115,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i global_step_xd = k.variable(0, dtype=tf.int64) global_step_zd = k.variable(0, dtype=tf.int64) + encoder_lowest_loss = math.inf + decoder_lowest_loss = math.inf + enc_dec_lowest_loss = math.inf + zd_lowest_loss = math.inf + xd_lowest_loss = math.inf + grace_period = 3 + for epoch in range(train_epoch): # define loss variables encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32) @@ -205,16 +213,22 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i epoch_end_time = time.time() per_epoch_time = epoch_end_time - epoch_start_time - + + # final losses of epoch + decoder_loss = decoder_loss_avg.result(False) + encoder_loss = encoder_loss_avg.result(False) + enc_dec_loss = enc_dec_loss_avg.result(False) + xd_loss = xd_loss_avg.result(False) + zd_loss = zd_loss_avg.result(False) if verbose: print(( f"[{epoch + 1:d}/{train_epoch:d}] - " f"train time: {per_epoch_time:.2f}, " - f"Decoder loss: {decoder_loss_avg.result(False)}, " - f"X Discriminator loss: {xd_loss_avg.result(False):.3f}, " - f"Z Discriminator loss: {zd_loss_avg.result(False):.3f}, " - f"Encoder + Decoder loss: {enc_dec_loss_avg.result(False):.3f}, " - f"Encoder loss: {encoder_loss_avg.result(False):.3f}" + f"Decoder loss: {decoder_loss}, " + f"X Discriminator loss: {xd_loss:.3f}, " + f"Z Discriminator loss: {zd_loss:.3f}, " + f"Encoder + Decoder loss: {enc_dec_loss:.3f}, " + f"Encoder loss: {encoder_loss:.3f}" )) # save sample image @@ -229,9 +243,43 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i ndarr = grid.cpu().numpy() im = Image.fromarray(ndarr) im.save(filename) + + # check for improvements in error reduction - otherwise early stopping + strike = False + if encoder_loss < encoder_lowest_loss: + encoder_lowest_loss = encoder_loss + else: + strike = True + if decoder_loss < decoder_lowest_loss: + decoder_lowest_loss = decoder_loss + else: + strike = True + if enc_dec_loss < enc_dec_lowest_loss: + enc_dec_lowest_loss = enc_dec_loss + else: + strike = True + if xd_loss < xd_lowest_loss: + xd_lowest_loss = xd_loss + else: + strike = True + if zd_loss < zd_lowest_loss: + zd_lowest_loss = zd_loss + else: + strike = True + + if strike: + grace_period -= 1 + else: + grace_period = 3 + + if grace_period == 0: + break if verbose: - print("Training finish!... save training results") + if grace_period > 0: + print("Training finish!... save model weights") + if grace_period == 0: + print("Training stopped early!... save model weights") # save trained models encoder.save_weights("./weights/encoder/" + str(inlier_classes[0])) @@ -259,13 +307,13 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder, :return: the calculated loss """ with tf.GradientTape() as tape: - xd_result = tf.squeeze(x_discriminator(inputs)) - xd_real_loss = binary_crossentropy(targets_real, xd_result) + xd_result_1 = tf.squeeze(x_discriminator(inputs)) + xd_real_loss = binary_crossentropy(targets_real, xd_result_1) z = z_generator() x_fake = decoder(z) - xd_result = tf.squeeze(x_discriminator(x_fake)) - xd_fake_loss = binary_crossentropy(targets_fake, xd_result) + xd_result_2 = tf.squeeze(x_discriminator(x_fake)) + xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2) _xd_train_loss = xd_real_loss + xd_fake_loss @@ -414,7 +462,7 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable: if __name__ == "__main__": tf.enable_eager_execution() inlier_classes = [0] - iteration = 2 + iteration = 1 train_summary_writer = summary_ops_v2.create_file_writer( './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration)) with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():