From 3bde9133235e3eb930848b297506847ceb0d0596 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Fri, 8 Feb 2019 13:27:19 +0100 Subject: [PATCH] Extended early stopping with total loss Signed-off-by: Jim Martens --- src/twomartens/masterthesis/aae/train.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 5bdc2f8..bdca3b3 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -120,6 +120,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i enc_dec_lowest_loss = math.inf zd_lowest_loss = math.inf xd_lowest_loss = math.inf + total_lowest_loss = math.inf grace_period = 3 for epoch in range(train_epoch): @@ -246,6 +247,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i # check for improvements in error reduction - otherwise early stopping strike = False + total_strike = False + total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss + if total_loss < total_lowest_loss: + total_lowest_loss = total_loss + else: + total_strike = True if encoder_loss < encoder_lowest_loss: encoder_lowest_loss = encoder_loss else: @@ -267,8 +274,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i else: strike = True - if strike: + if strike and total_strike: grace_period -= 1 + elif strike: + pass else: grace_period = 3