Implemented early stopping
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -16,6 +16,7 @@
|
||||
|
||||
"""aae.train.py: contains training functionality"""
|
||||
import functools
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
@ -114,6 +115,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
global_step_xd = k.variable(0, dtype=tf.int64)
|
||||
global_step_zd = k.variable(0, dtype=tf.int64)
|
||||
|
||||
encoder_lowest_loss = math.inf
|
||||
decoder_lowest_loss = math.inf
|
||||
enc_dec_lowest_loss = math.inf
|
||||
zd_lowest_loss = math.inf
|
||||
xd_lowest_loss = math.inf
|
||||
grace_period = 3
|
||||
|
||||
for epoch in range(train_epoch):
|
||||
# define loss variables
|
||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||
@ -206,15 +214,21 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
epoch_end_time = time.time()
|
||||
per_epoch_time = epoch_end_time - epoch_start_time
|
||||
|
||||
# final losses of epoch
|
||||
decoder_loss = decoder_loss_avg.result(False)
|
||||
encoder_loss = encoder_loss_avg.result(False)
|
||||
enc_dec_loss = enc_dec_loss_avg.result(False)
|
||||
xd_loss = xd_loss_avg.result(False)
|
||||
zd_loss = zd_loss_avg.result(False)
|
||||
if verbose:
|
||||
print((
|
||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||
f"train time: {per_epoch_time:.2f}, "
|
||||
f"Decoder loss: {decoder_loss_avg.result(False)}, "
|
||||
f"X Discriminator loss: {xd_loss_avg.result(False):.3f}, "
|
||||
f"Z Discriminator loss: {zd_loss_avg.result(False):.3f}, "
|
||||
f"Encoder + Decoder loss: {enc_dec_loss_avg.result(False):.3f}, "
|
||||
f"Encoder loss: {encoder_loss_avg.result(False):.3f}"
|
||||
f"Decoder loss: {decoder_loss}, "
|
||||
f"X Discriminator loss: {xd_loss:.3f}, "
|
||||
f"Z Discriminator loss: {zd_loss:.3f}, "
|
||||
f"Encoder + Decoder loss: {enc_dec_loss:.3f}, "
|
||||
f"Encoder loss: {encoder_loss:.3f}"
|
||||
))
|
||||
|
||||
# save sample image
|
||||
@ -230,8 +244,42 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filename)
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
strike = False
|
||||
if encoder_loss < encoder_lowest_loss:
|
||||
encoder_lowest_loss = encoder_loss
|
||||
else:
|
||||
strike = True
|
||||
if decoder_loss < decoder_lowest_loss:
|
||||
decoder_lowest_loss = decoder_loss
|
||||
else:
|
||||
strike = True
|
||||
if enc_dec_loss < enc_dec_lowest_loss:
|
||||
enc_dec_lowest_loss = enc_dec_loss
|
||||
else:
|
||||
strike = True
|
||||
if xd_loss < xd_lowest_loss:
|
||||
xd_lowest_loss = xd_loss
|
||||
else:
|
||||
strike = True
|
||||
if zd_loss < zd_lowest_loss:
|
||||
zd_lowest_loss = zd_loss
|
||||
else:
|
||||
strike = True
|
||||
|
||||
if strike:
|
||||
grace_period -= 1
|
||||
else:
|
||||
grace_period = 3
|
||||
|
||||
if grace_period == 0:
|
||||
break
|
||||
|
||||
if verbose:
|
||||
print("Training finish!... save training results")
|
||||
if grace_period > 0:
|
||||
print("Training finish!... save model weights")
|
||||
if grace_period == 0:
|
||||
print("Training stopped early!... save model weights")
|
||||
|
||||
# save trained models
|
||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
|
||||
@ -259,13 +307,13 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
||||
:return: the calculated loss
|
||||
"""
|
||||
with tf.GradientTape() as tape:
|
||||
xd_result = tf.squeeze(x_discriminator(inputs))
|
||||
xd_real_loss = binary_crossentropy(targets_real, xd_result)
|
||||
xd_result_1 = tf.squeeze(x_discriminator(inputs))
|
||||
xd_real_loss = binary_crossentropy(targets_real, xd_result_1)
|
||||
|
||||
z = z_generator()
|
||||
x_fake = decoder(z)
|
||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
||||
xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
|
||||
xd_result_2 = tf.squeeze(x_discriminator(x_fake))
|
||||
xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
|
||||
|
||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||
|
||||
@ -414,7 +462,7 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
|
||||
if __name__ == "__main__":
|
||||
tf.enable_eager_execution()
|
||||
inlier_classes = [0]
|
||||
iteration = 2
|
||||
iteration = 1
|
||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||
|
||||
Reference in New Issue
Block a user