Added prediction functionality to ssd module
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -31,6 +31,7 @@ Classes:
|
||||
``SSD``: wraps vanilla SSD 300 model
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
@ -99,6 +100,100 @@ class DropoutSSD:
|
||||
return self._model(inputs)
|
||||
|
||||
|
||||
def predict(dataset: tf.data.Dataset,
|
||||
use_dropout: bool,
|
||||
output_path: str,
|
||||
weights_path: Optional[str] = None,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
forward_passes_per_image: Optional[int] = 42) -> None:
|
||||
"""
|
||||
Run trained SSD on the given data set.
|
||||
|
||||
Either the weights path or the checkpoint path must be given. This prevents
|
||||
a scenario where an untrained network is used to predict.
|
||||
|
||||
The prediction results are saved to the output path.
|
||||
|
||||
Args:
|
||||
dataset: the testing data set
|
||||
use_dropout: if True, DropoutSSD will be used
|
||||
output_path: the path in which the results should be saved
|
||||
weights_path: the path to the trained Keras weights (h5 file)
|
||||
checkpoint_path: the path to the stored checkpoints (Tensorflow checkpoints)
|
||||
verbose: if True, progress is printed to the standard output
|
||||
forward_passes_per_image: specifies number of forward passes per image
|
||||
used by DropoutSSD
|
||||
"""
|
||||
if weights_path is None and checkpoint_path is None:
|
||||
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
|
||||
|
||||
checkpointables = {}
|
||||
if use_dropout:
|
||||
checkpointables.update({
|
||||
'ssd': DropoutSSD(mode='inference_fast', weights_path=weights_path)
|
||||
})
|
||||
else:
|
||||
checkpointables.update({
|
||||
'ssd': SSD(mode='inference_fast', weights_path=weights_path)
|
||||
})
|
||||
|
||||
if checkpoint_path is not None:
|
||||
# checkpoint
|
||||
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||
checkpoint = tf.train.Checkpoint(**checkpointables)
|
||||
checkpoint.restore(latest_checkpoint)
|
||||
|
||||
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
|
||||
|
||||
# save predictions
|
||||
filename = 'ssd_predictions.bin'
|
||||
if use_dropout:
|
||||
filename = 'dropout-' + filename
|
||||
|
||||
with open(output_path + filename, 'wb') as file:
|
||||
pickle.dump(outputs, file)
|
||||
|
||||
if verbose:
|
||||
print((
|
||||
f"predict time: {outputs['per_epoch_time']:.2f}, "
|
||||
))
|
||||
print("Prediction finished!... save outputs")
|
||||
|
||||
|
||||
def _predict_one_epoch(dataset: tf.data.Dataset,
|
||||
use_dropout: bool,
|
||||
forward_passes_per_image: int,
|
||||
ssd: tf.keras.Model) -> Dict[str, float]:
|
||||
|
||||
with summary_ops_v2.always_record_summaries():
|
||||
epoch_start_time = time.time()
|
||||
decoded_predictions = []
|
||||
|
||||
# go through the data set
|
||||
for inputs, _ in dataset:
|
||||
decoded_predictions_batch = []
|
||||
if use_dropout:
|
||||
for _ in range(forward_passes_per_image):
|
||||
decoded_predictions_pass = ssd(inputs)
|
||||
decoded_predictions_batch.append(decoded_predictions_pass)
|
||||
else:
|
||||
decoded_predictions_batch.append(ssd(inputs))
|
||||
|
||||
decoded_predictions.append(decoded_predictions_batch)
|
||||
|
||||
epoch_end_time = time.time()
|
||||
per_epoch_time = epoch_end_time - epoch_start_time
|
||||
|
||||
# outputs for epoch
|
||||
outputs = {
|
||||
'per_epoch_time': per_epoch_time,
|
||||
'decoded_predictions': decoded_predictions
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def train(dataset: tf.data.Dataset,
|
||||
iteration: int,
|
||||
use_dropout: bool,
|
||||
|
||||
Reference in New Issue
Block a user