diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index d7acd39..d74405d 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -35,6 +35,8 @@ AdamOptimizer = tf.train.AdamOptimizer tfe = tf.contrib.eager binary_crossentropy = tf.keras.losses.binary_crossentropy +GRACE = 10 + def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int, iteration: int, @@ -123,7 +125,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i zd_lowest_loss = math.inf xd_lowest_loss = math.inf total_lowest_loss = math.inf - grace_period = 3 + grace_period = GRACE for epoch in range(train_epoch): # define loss variables @@ -253,7 +255,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss if total_loss < total_lowest_loss: total_lowest_loss = total_loss - else: + elif total_loss > 6: total_strike = True if encoder_loss < encoder_lowest_loss: encoder_lowest_loss = encoder_loss @@ -281,7 +283,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i elif strike: pass else: - grace_period = 3 + grace_period = GRACE if grace_period == 0: break