Extended early stopping with total loss
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -120,6 +120,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
enc_dec_lowest_loss = math.inf
|
||||
zd_lowest_loss = math.inf
|
||||
xd_lowest_loss = math.inf
|
||||
total_lowest_loss = math.inf
|
||||
grace_period = 3
|
||||
|
||||
for epoch in range(train_epoch):
|
||||
@ -246,6 +247,12 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
|
||||
# check for improvements in error reduction - otherwise early stopping
|
||||
strike = False
|
||||
total_strike = False
|
||||
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
|
||||
if total_loss < total_lowest_loss:
|
||||
total_lowest_loss = total_loss
|
||||
else:
|
||||
total_strike = True
|
||||
if encoder_loss < encoder_lowest_loss:
|
||||
encoder_lowest_loss = encoder_loss
|
||||
else:
|
||||
@ -267,8 +274,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
||||
else:
|
||||
strike = True
|
||||
|
||||
if strike:
|
||||
if strike and total_strike:
|
||||
grace_period -= 1
|
||||
elif strike:
|
||||
pass
|
||||
else:
|
||||
grace_period = 3
|
||||
|
||||
|
||||
Reference in New Issue
Block a user