diff --git a/src/twomartens/masterthesis/aae/train.py b/src/twomartens/masterthesis/aae/train.py index 5bdc2f8..bdca3b3 100644 --- a/src/twomartens/masterthesis/aae/train.py +++ b/src/twomartens/masterthesis/aae/train.py @@ -120,6 +120,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i enc_dec_lowest_loss = math.inf zd_lowest_loss = math.inf xd_lowest_loss = math.inf + total_lowest_loss = math.inf grace_period = 3 for epoch in range(train_epoch): @@ -246,6 +247,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i # check for improvements in error reduction - otherwise early stopping strike = False + total_strike = False + total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss + if total_loss < total_lowest_loss: + total_lowest_loss = total_loss + else: + total_strike = True if encoder_loss < encoder_lowest_loss: encoder_lowest_loss = encoder_loss else: @@ -267,8 +274,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i else: strike = True - if strike: + if strike and total_strike: grace_period -= 1 + elif strike: + pass else: grace_period = 3