Prevented unnecessary early stops
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -35,6 +35,8 @@ AdamOptimizer = tf.train.AdamOptimizer
|
|||||||
tfe = tf.contrib.eager
|
tfe = tf.contrib.eager
|
||||||
binary_crossentropy = tf.keras.losses.binary_crossentropy
|
binary_crossentropy = tf.keras.losses.binary_crossentropy
|
||||||
|
|
||||||
|
GRACE = 10
|
||||||
|
|
||||||
|
|
||||||
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: int,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
@ -123,7 +125,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
zd_lowest_loss = math.inf
|
zd_lowest_loss = math.inf
|
||||||
xd_lowest_loss = math.inf
|
xd_lowest_loss = math.inf
|
||||||
total_lowest_loss = math.inf
|
total_lowest_loss = math.inf
|
||||||
grace_period = 3
|
grace_period = GRACE
|
||||||
|
|
||||||
for epoch in range(train_epoch):
|
for epoch in range(train_epoch):
|
||||||
# define loss variables
|
# define loss variables
|
||||||
@ -253,7 +255,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
|
total_loss = encoder_loss + decoder_loss + enc_dec_loss + xd_loss + zd_loss
|
||||||
if total_loss < total_lowest_loss:
|
if total_loss < total_lowest_loss:
|
||||||
total_lowest_loss = total_loss
|
total_lowest_loss = total_loss
|
||||||
else:
|
elif total_loss > 6:
|
||||||
total_strike = True
|
total_strike = True
|
||||||
if encoder_loss < encoder_lowest_loss:
|
if encoder_loss < encoder_lowest_loss:
|
||||||
encoder_lowest_loss = encoder_loss
|
encoder_lowest_loss = encoder_loss
|
||||||
@ -281,7 +283,7 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
|
|||||||
elif strike:
|
elif strike:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
grace_period = 3
|
grace_period = GRACE
|
||||||
|
|
||||||
if grace_period == 0:
|
if grace_period == 0:
|
||||||
break
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user