Adjusted predict function to new checkpoint design
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -171,14 +171,24 @@ def predict(dataset: tf.data.Dataset,
|
|||||||
if weights_path is None and checkpoint_path is None:
|
if weights_path is None and checkpoint_path is None:
|
||||||
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
|
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
|
||||||
|
|
||||||
checkpointables = {}
|
# model
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
checkpointables.update({
|
ssd = DropoutSSD(mode='training', weights_path=weights_path)
|
||||||
'ssd': DropoutSSD(mode='training', weights_path=weights_path)
|
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
|
ssd = SSD(mode='training', weights_path=weights_path)
|
||||||
|
|
||||||
|
checkpointables = {
|
||||||
|
'ssd': ssd.model,
|
||||||
|
'learning_rate_var': K.variable(0),
|
||||||
|
}
|
||||||
|
|
||||||
checkpointables.update({
|
checkpointables.update({
|
||||||
'ssd': SSD(mode='inference_fast', weights_path=weights_path)
|
# optimizer
|
||||||
|
'ssd_optimizer': tf.train.AdamOptimizer(learning_rate=checkpointables['learning_rate_var'],
|
||||||
|
beta1=0.5, beta2=0.999),
|
||||||
|
# global step counter
|
||||||
|
'global_step': tf.train.get_or_create_global_step(),
|
||||||
|
'epoch_var': K.variable(-1, dtype=tf.int64)
|
||||||
})
|
})
|
||||||
|
|
||||||
if checkpoint_path is not None:
|
if checkpoint_path is not None:
|
||||||
@ -188,7 +198,7 @@ def predict(dataset: tf.data.Dataset,
|
|||||||
checkpoint.restore(latest_checkpoint)
|
checkpoint.restore(latest_checkpoint)
|
||||||
|
|
||||||
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image,
|
outputs = _predict_one_epoch(dataset, use_dropout, output_path, forward_passes_per_image,
|
||||||
nr_digits, **checkpointables)
|
nr_digits, checkpointables['ssd'])
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print((
|
print((
|
||||||
|
|||||||
Reference in New Issue
Block a user