Fixed specification of checkpoint dir

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-02-08 14:59:50 +01:00
parent 720ff63c7d
commit 279e848434

View File

@ -127,8 +127,10 @@ def train_mnist(folding_id: int, inlier_classes: Sequence[int], total_classes: i
total_lowest_loss = math.inf total_lowest_loss = math.inf
grace_period = GRACE grace_period = GRACE
checkpoint_prefix = './weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/ckpt' checkpoint_dir = os.path.abspath('./weights/' + str(inlier_classes[0]) + '/' + str(iteration) + '/')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_prefix) os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint = tf.train.Checkpoint(encoder=encoder, checkpoint = tf.train.Checkpoint(encoder=encoder,
decoder=decoder, decoder=decoder,
z_discriminator=z_discriminator, z_discriminator=z_discriminator,