Implemented early stopping

Signed-off-by: Jim Martens <github@2martens.de>
2019-02-08 13:20:17 +01:00
parent 682e11b435
commit c19ff19820

@@ -16,6 +16,7 @@
 """aae.train.py: contains training functionality"""
 import functools
+import math
 import os
 import pickle
 import time
@@ -114,6 +115,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
     global_step_xd = k.variable(0, dtype=tf.int64)
     global_step_zd = k.variable(0, dtype=tf.int64)
+    encoder_lowest_loss = math.inf
+    decoder_lowest_loss = math.inf
+    enc_dec_lowest_loss = math.inf
+    zd_lowest_loss = math.inf
+    xd_lowest_loss = math.inf
+    grace_period = 3
     for epoch in range(train_epoch):
         # define loss variables
         encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
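The five best-loss trackers start at `math.inf`, so the first epoch is guaranteed to register as an improvement, and `grace_period = 3` tolerates three consecutive non-improving epochs before training stops. A hypothetical condensed form of the same state (dict-based; names not from this commit):

```python
import math

# One best-so-far entry per tracked model part; math.inf guarantees the
# first measured loss counts as an improvement.
lowest_losses = {part: math.inf
                 for part in ("encoder", "decoder", "enc_dec", "zd", "xd")}
grace_period = 3  # consecutive non-improving epochs tolerated
```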
@@ -206,15 +214,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
         epoch_end_time = time.time()
         per_epoch_time = epoch_end_time - epoch_start_time
+        # final losses of epoch
+        decoder_loss = decoder_loss_avg.result(False)
+        encoder_loss = encoder_loss_avg.result(False)
+        enc_dec_loss = enc_dec_loss_avg.result(False)
+        xd_loss = xd_loss_avg.result(False)
+        zd_loss = zd_loss_avg.result(False)
         if verbose:
             print((
                 f"[{epoch + 1:d}/{train_epoch:d}] - "
                 f"train time: {per_epoch_time:.2f}, "
-                f"Decoder loss: {decoder_loss_avg.result(False)}, "
-                f"X Discriminator loss: {xd_loss_avg.result(False):.3f}, "
-                f"Z Discriminator loss: {zd_loss_avg.result(False):.3f}, "
-                f"Encoder + Decoder loss: {enc_dec_loss_avg.result(False):.3f}, "
-                f"Encoder loss: {encoder_loss_avg.result(False):.3f}"
+                f"Decoder loss: {decoder_loss}, "
+                f"X Discriminator loss: {xd_loss:.3f}, "
+                f"Z Discriminator loss: {zd_loss:.3f}, "
+                f"Encoder + Decoder loss: {enc_dec_loss:.3f}, "
+                f"Encoder loss: {encoder_loss:.3f}"
             ))
         # save sample image
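Binding each metric to a local means the eager metric is read once per epoch, with the same value feeding both the log line and the early-stopping comparison below. Assuming the `tf.contrib.eager` metrics API, the `False` argument is the `write_summary` flag, so these reads presumably avoid emitting an extra summary event. A minimal sketch of the pattern:

```python
# Read the running mean once, then reuse the value everywhere.
# Assumes tfe.metrics.Mean.result(write_summary=True) as in tf.contrib.eager.
decoder_loss = decoder_loss_avg.result(False)  # no summary written on read
if verbose:
    print(f"Decoder loss: {decoder_loss:.3f}")
if decoder_loss < decoder_lowest_loss:
    decoder_lowest_loss = decoder_loss
```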
@@ -230,8 +244,42 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
             im = Image.fromarray(ndarr)
             im.save(filename)
+        # check for improvements in error reduction - otherwise early stopping
+        strike = False
+        if encoder_loss < encoder_lowest_loss:
+            encoder_lowest_loss = encoder_loss
+        else:
+            strike = True
+        if decoder_loss < decoder_lowest_loss:
+            decoder_lowest_loss = decoder_loss
+        else:
+            strike = True
+        if enc_dec_loss < enc_dec_lowest_loss:
+            enc_dec_lowest_loss = enc_dec_loss
+        else:
+            strike = True
+        if xd_loss < xd_lowest_loss:
+            xd_lowest_loss = xd_loss
+        else:
+            strike = True
+        if zd_loss < zd_lowest_loss:
+            zd_lowest_loss = zd_loss
+        else:
+            strike = True
+        if strike:
+            grace_period -= 1
+        else:
+            grace_period = 3
+        if grace_period == 0:
+            break
     if verbose:
-        print("Training finish!... save training results")
+        if grace_period > 0:
+            print("Training finish!... save model weights")
+        if grace_period == 0:
+            print("Training stopped early!... save model weights")
     # save trained models
     encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
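The check treats an epoch as a "strike" if any of the five losses fails to beat its best value so far; three strike epochs in a row end training, and any fully improving epoch resets the counter. The five near-identical `if`/`else` blocks could be collapsed into a loop; a sketch of that refactor, with hypothetical names, not the committed code:

```python
def early_stopping_check(losses: dict, lowest: dict,
                         grace_period: int, patience: int = 3) -> int:
    """Return the updated grace period; 0 signals that training should stop."""
    strike = False
    for name, loss in losses.items():
        if loss < lowest[name]:
            lowest[name] = loss  # new best value for this model part
        else:
            strike = True        # at least one loss failed to improve
    # a strike epoch burns one grace epoch; a clean epoch restores patience
    return grace_period - 1 if strike else patience

# inside the epoch loop:
# grace_period = early_stopping_check(epoch_losses, lowest_losses, grace_period)
# if grace_period == 0:
#     break
```

One caveat worth noting: requiring all five losses to improve simultaneously is strict for adversarial training, where discriminator and generator losses usually trade off against each other; the three-epoch grace period is what keeps this workable.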
@@ -259,13 +307,13 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
     :return: the calculated loss
     """
     with tf.GradientTape() as tape:
-        xd_result = tf.squeeze(x_discriminator(inputs))
-        xd_real_loss = binary_crossentropy(targets_real, xd_result)
+        xd_result_1 = tf.squeeze(x_discriminator(inputs))
+        xd_real_loss = binary_crossentropy(targets_real, xd_result_1)
         z = z_generator()
         x_fake = decoder(z)
-        xd_result = tf.squeeze(x_discriminator(x_fake))
-        xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
+        xd_result_2 = tf.squeeze(x_discriminator(x_fake))
+        xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
         _xd_train_loss = xd_real_loss + xd_fake_loss
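Renaming the reused `xd_result` to `xd_result_1`/`xd_result_2` doesn't change behavior: under eager execution the `GradientTape` records the operations themselves, so rebinding a Python name between the real and fake passes was already harmless. The distinct names just make explicit that two different tensors are involved, one per discriminator pass:

```python
# Illustrative shape of the step after the rename (assumed surrounding
# definitions; not the full committed function).
with tf.GradientTape() as tape:
    xd_result_1 = tf.squeeze(x_discriminator(inputs))  # D(x) on real images
    xd_real_loss = binary_crossentropy(targets_real, xd_result_1)
    xd_result_2 = tf.squeeze(x_discriminator(decoder(z_generator())))  # D(G(z))
    xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
    _xd_train_loss = xd_real_loss + xd_fake_loss       # total discriminator loss
```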
@@ -414,7 +462,7 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
 if __name__ == "__main__":
     tf.enable_eager_execution()
     inlier_classes = [0]
-    iteration = 2
+    iteration = 1
     train_summary_writer = summary_ops_v2.create_file_writer(
         './summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
     with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():