Extended early stopping with total loss

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 13:27:19 +01:00
parent c19ff19820
commit 3bde913323

View File

@ -120,6 +120,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
enc_dec_lowest_loss = math.inf
zd_lowest_loss = math.inf
xd_lowest_loss = math.inf
total_lowest_loss = math.inf
grace_period = 3
for epoch in range(train_epoch):
@ -246,6 +247,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
# check for improvements in error reduction - otherwise early stopping
strike = False
total_strike = False
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
if total_loss < total_lowest_loss:
total_lowest_loss = total_loss
else:
total_strike = True
if encoder_loss < encoder_lowest_loss:
encoder_lowest_loss = encoder_loss
else:
@ -267,8 +274,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
else:
strike = True
if strike:
if strike and total_strike:
grace_period -= 1
elif strike:
pass
else:
grace_period = 3