Adjusted predict function to new checkpoint design
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -171,15 +171,25 @@ def predict(dataset: tf.data.Dataset,
|
||||
if weights_path is None and checkpoint_path is None:
|
||||
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
|
||||
|
||||
checkpointables = {}
|
||||
# model
|
||||
if use_dropout:
|
||||
checkpointables.update({
|
||||
'ssd': DropoutSSD(mode='training', weights_path=weights_path)
|
||||
})
|
||||
ssd = DropoutSSD(mode='training', weights_path=weights_path)
|
||||
else:
|
||||
checkpointables.update({
|
||||
'ssd': SSD(mode='inference_fast', weights_path=weights_path)
|
||||
})
|
||||
ssd = SSD(mode='training', weights_path=weights_path)
|
||||
|
||||
checkpointables = {
|
||||
'ssd': ssd.model,
|
||||
'learning_rate_var': K.variable(0),
|
||||
}
|
||||
|
||||
checkpointables.update({
|
||||
# optimizer
|
||||
'ssd_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||
beta1=0.5, beta2=0.999),
|
||||
# global step counter
|
||||
'global_step': tf.train.get_or_create_global_step(),
|
||||
'epoch_var': K.variable(-1, dtype=tf.int64)
|
||||
})
|
||||
|
||||
if checkpoint_path is not None:
|
||||
# checkpoint
|
||||
@ -188,7 +198,7 @@ def predict(dataset: tf.data.Dataset,
|
||||
checkpoint.restore(latest_checkpoint)
|
||||
|
||||
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image,
|
||||
nr_digits, **checkpointables)
|
||||
nr_digits, checkpointables['ssd'])
|
||||
|
||||
if verbose:
|
||||
print((
|
||||
|
||||
Reference in New Issue
Block a user