Prevented unnecessary early stops

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 13:51:41 +01:00
parent 948ef64423
commit ec9b64a757

View File

@ -35,6 +35,8 @@ AdamOptimizer = tf.train.AdamOptimizer
tfe = tf.contrib.eager
binary_crossentropy = tf.keras.losses.binary_crossentropy
GRACE = 10
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
                iteration: int,
@ -123,7 +125,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
    zd_lowest_loss = math.inf
    xd_lowest_loss = math.inf
    total_lowest_loss = math.inf
grace_period = 3 grace_period = GRACE
    for epoch in range(train_epoch):
        # define loss variables
@ -253,7 +255,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
        total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
        if total_loss < total_lowest_loss:
            total_lowest_loss = total_loss
else: elif total_loss > 6:
            total_strike = True
        if encoder_loss < encoder_lowest_loss:
            encoder_lowest_loss = encoder_loss
@ -281,7 +283,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
        elif strike:
            pass
        else:
grace_period = 3 grace_period = GRACE
        if grace_period == 0:
            break