Implemented early stopping
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
"""aae.train.py: contains training functionality"""
|
"""aae.train.py: contains training functionality"""
|
||||||
import functools
|
import functools
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
@ -114,6 +115,13 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
global_step_xd = k.variable(0, dtype=tf.int64)
|
global_step_xd = k.variable(0, dtype=tf.int64)
|
||||||
global_step_zd = k.variable(0, dtype=tf.int64)
|
global_step_zd = k.variable(0, dtype=tf.int64)
|
||||||
|
|
||||||
|
encoder_lowest_loss = math.inf
|
||||||
|
decoder_lowest_loss = math.inf
|
||||||
|
enc_dec_lowest_loss = math.inf
|
||||||
|
zd_lowest_loss = math.inf
|
||||||
|
xd_lowest_loss = math.inf
|
||||||
|
grace_period = 3
|
||||||
|
|
||||||
for epoch in range(train_epoch):
|
for epoch in range(train_epoch):
|
||||||
# define loss variables
|
# define loss variables
|
||||||
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
encoder_loss_avg = tfe.metrics.Mean(name='encoder_loss', dtype=tf.float32)
|
||||||
@ -205,16 +213,22 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
|
|
||||||
epoch_end_time = time.time()
|
epoch_end_time = time.time()
|
||||||
per_epoch_time = epoch_end_time - epoch_start_time
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
|
|
||||||
|
# final losses of epoch
|
||||||
|
decoder_loss = decoder_loss_avg.result(False)
|
||||||
|
encoder_loss = encoder_loss_avg.result(False)
|
||||||
|
enc_dec_loss = enc_dec_loss_avg.result(False)
|
||||||
|
xd_loss = xd_loss_avg.result(False)
|
||||||
|
zd_loss = zd_loss_avg.result(False)
|
||||||
if verbose:
|
if verbose:
|
||||||
print((
|
print((
|
||||||
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
f"[{epoch + 1:d}/{train_epoch:d}] - "
|
||||||
f"train time: {per_epoch_time:.2f}, "
|
f"train time: {per_epoch_time:.2f}, "
|
||||||
f"Decoder loss: {decoder_loss_avg.result(False)}, "
|
f"Decoder loss: {decoder_loss}, "
|
||||||
f"X Discriminator loss: {xd_loss_avg.result(False):.3f}, "
|
f"X Discriminator loss: {xd_loss:.3f}, "
|
||||||
f"Z Discriminator loss: {zd_loss_avg.result(False):.3f}, "
|
f"Z Discriminator loss: {zd_loss:.3f}, "
|
||||||
f"Encoder + Decoder loss: {enc_dec_loss_avg.result(False):.3f}, "
|
f"Encoder + Decoder loss: {enc_dec_loss:.3f}, "
|
||||||
f"Encoder loss: {encoder_loss_avg.result(False):.3f}"
|
f"Encoder loss: {encoder_loss:.3f}"
|
||||||
))
|
))
|
||||||
|
|
||||||
# save sample image
|
# save sample image
|
||||||
@ -229,9 +243,43 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
ndarr = grid.cpu().numpy()
|
ndarr = grid.cpu().numpy()
|
||||||
im = Image.fromarray(ndarr)
|
im = Image.fromarray(ndarr)
|
||||||
im.save(filename)
|
im.save(filename)
|
||||||
|
|
||||||
|
# early stopping: an epoch in which any loss fails to improve on its lowest value so far uses up one grace period; training stops once the grace period reaches zero
|
||||||
|
strike = False
|
||||||
|
if encoder_loss < encoder_lowest_loss:
|
||||||
|
encoder_lowest_loss = encoder_loss
|
||||||
|
else:
|
||||||
|
strike = True
|
||||||
|
if decoder_loss < decoder_lowest_loss:
|
||||||
|
decoder_lowest_loss = decoder_loss
|
||||||
|
else:
|
||||||
|
strike = True
|
||||||
|
if enc_dec_loss < enc_dec_lowest_loss:
|
||||||
|
enc_dec_lowest_loss = enc_dec_loss
|
||||||
|
else:
|
||||||
|
strike = True
|
||||||
|
if xd_loss < xd_lowest_loss:
|
||||||
|
xd_lowest_loss = xd_loss
|
||||||
|
else:
|
||||||
|
strike = True
|
||||||
|
if zd_loss < zd_lowest_loss:
|
||||||
|
zd_lowest_loss = zd_loss
|
||||||
|
else:
|
||||||
|
strike = True
|
||||||
|
|
||||||
|
if strike:
|
||||||
|
grace_period -= 1
|
||||||
|
else:
|
||||||
|
grace_period = 3
|
||||||
|
|
||||||
|
if grace_period == 0:
|
||||||
|
break
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Training finish!... save training results")
|
if grace_period > 0:
|
||||||
|
print("Training finish!... save model weights")
|
||||||
|
if grace_period == 0:
|
||||||
|
print("Training stopped early!... save model weights")
|
||||||
|
|
||||||
# save trained models
|
# save trained models
|
||||||
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
|
encoder.save_weights("./weights/encoder/" + str(inlier_classes[0]))
|
||||||
@ -259,13 +307,13 @@ def train_xdiscriminator_step(x_discriminator: XDiscriminator, decoder: Decoder,
|
|||||||
:return: the calculated loss
|
:return: the calculated loss
|
||||||
"""
|
"""
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
xd_result = tf.squeeze(x_discriminator(inputs))
|
xd_result_1 = tf.squeeze(x_discriminator(inputs))
|
||||||
xd_real_loss = binary_crossentropy(targets_real, xd_result)
|
xd_real_loss = binary_crossentropy(targets_real, xd_result_1)
|
||||||
|
|
||||||
z = z_generator()
|
z = z_generator()
|
||||||
x_fake = decoder(z)
|
x_fake = decoder(z)
|
||||||
xd_result = tf.squeeze(x_discriminator(x_fake))
|
xd_result_2 = tf.squeeze(x_discriminator(x_fake))
|
||||||
xd_fake_loss = binary_crossentropy(targets_fake, xd_result)
|
xd_fake_loss = binary_crossentropy(targets_fake, xd_result_2)
|
||||||
|
|
||||||
_xd_train_loss = xd_real_loss + xd_fake_loss
|
_xd_train_loss = xd_real_loss + xd_fake_loss
|
||||||
|
|
||||||
@ -414,7 +462,7 @@ def extract_batch(data: np.ndarray, it: int, batch_size: int) -> tfe.Variable:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
inlier_classes = [0]
|
inlier_classes = [0]
|
||||||
iteration = 2
|
iteration = 1
|
||||||
train_summary_writer = summary_ops_v2.create_file_writer(
|
train_summary_writer = summary_ops_v2.create_file_writer(
|
||||||
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
'./summaries/train/number-' + str(inlier_classes[0]) + '/' + str(iteration))
|
||||||
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
with train_summary_writer.as_default(), summary_ops_v2.always_record_summaries():
|
||||||
|
|||||||
Reference in New Issue
Block a user