Adjusted predict function to new checkpoint design

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-06-10 11:17:17 +02:00
parent 110e098d78
commit 4b85dd8376

View File

@ -171,14 +171,24 @@ def predict(dataset: tf.data.Dataset,
if weights_path is None and checkpoint_path is None: if weights_path is None and checkpoint_path is None:
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.") raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
checkpointables = {} # model
if use_dropout: if use_dropout:
checkpointables.update({ ssd = DropoutSSD(mode='training', weights_path=weights_path)
'ssd': DropoutSSD(mode='training', weights_path=weights_path)
})
else: else:
ssd = SSD(mode='training', weights_path=weights_path)
checkpointables = {
'ssd': ssd.model,
'learning_rate_var': K.variable(0),
}
checkpointables.update({ checkpointables.update({
'ssd': SSD(mode='inference_fast', weights_path=weights_path) # optimizer
'ssd_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
beta1=0.5, beta2=0.999),
# global step counter
'global_step': tf.train.get_or_create_global_step(),
'epoch_var': K.variable(-1, dtype=tf.int64)
}) })
if checkpoint_path is not None: if checkpoint_path is not None:
@ -188,7 +198,7 @@ def predict(dataset: tf.data.Dataset,
checkpoint.restore(latest_checkpoint) checkpoint.restore(latest_checkpoint)
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image, outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image,
nr_digits, **checkpointables) nr_digits, checkpointables['ssd'])
if verbose: if verbose:
print(( print((