From ec9b64a75722c63a3fdeec8e00f8bf8b947643f1 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 13:51:41 +0100 Subject: [PATCH] Prevented unnecessary early stops Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index d7acd39..d74405d 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -35,6 +35,8 @@ AdamOptimizer = tf.train.AdamOptimizer tfe = tf.contrib.eager binary_crossentropy = tf.keras.losses.binary_crossentropy +GRACE = 10 + def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int, iteration: int, @@ -123,7 +125,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i zd_lowest_loss = math.inf xd_lowest_loss = math.inf total_lowest_loss = math.inf - grace_period = 3 + grace_period = GRACE for epoch in range(train_epoch): # define loss variables @@ -253,7 +255,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss if total_loss < total_lowest_loss: total_lowest_loss = total_loss - else: + elif total_loss > 6: total_strike = True if encoder_loss < encoder_lowest_loss: encoder_lowest_loss = encoder_loss @@ -281,7 +283,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i elif strike: pass else: - grace_period = 3 + grace_period = GRACE if grace_period == 0: break