Added prediction functionality to ssd module

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-03-21 17:12:17 +01:00
parent f40d04e34c
commit 9eb1283b8b

View File

@ -31,6 +31,7 @@ Classes:
``SSD``: wraps vanilla SSD 300 model
"""
import os
import pickle
import time
from typing import Dict
from typing import Optional
@ -97,6 +98,100 @@ class DropoutSSD:
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
return self._model(inputs)
def predict(dataset: tf.data.Dataset,
use_dropout: bool,
output_path: str,
weights_path: Optional[str] = None,
checkpoint_path: Optional[str] = None,
verbose: Optional[bool] = False,
forward_passes_per_image: Optional[int] = 42) -> None:
"""
Run trained SSD on the given data set.
Either the weights path or the checkpoint path must be given. This prevents
a scenario where an untrained network is used to predict.
The prediction results are saved to the output path.
Args:
dataset: the testing data set
use_dropout: if True, DropoutSSD will be used
output_path: the path in which the results should be saved
weights_path: the path to the trained Keras weights (h5 file)
checkpoint_path: the path to the stored checkpoints (Tensorflow checkpoints)
verbose: if True, progress is printed to the standard output
forward_passes_per_image: specifies number of forward passes per image
used by DropoutSSD
"""
if weights_path is None and checkpoint_path is None:
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
checkpointables = {}
if use_dropout:
checkpointables.update({
'ssd': DropoutSSD(mode='inference_fast', weights_path=weights_path)
})
else:
checkpointables.update({
'ssd': SSD(mode='inference_fast', weights_path=weights_path)
})
if checkpoint_path is not None:
# checkpoint
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
checkpoint = tf.train.Checkpoint(**checkpointables)
checkpoint.restore(latest_checkpoint)
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
# save predictions
filename = 'ssd_predictions.bin'
if use_dropout:
filename = 'dropout-' + filename
with open(output_path + filename, 'wb') as file:
pickle.dump(outputs, file)
if verbose:
print((
f"predict time: {outputs['per_epoch_time']:.2f}, "
))
print("Prediction finished!... save outputs")
def _predict_one_epoch(dataset: tf.data.Dataset,
use_dropout: bool,
forward_passes_per_image: int,
ssd: tf.keras.Model) -> Dict[str, float]:
with summary_ops_v2.always_record_summaries():
epoch_start_time = time.time()
decoded_predictions = []
# go through the data set
for inputs, _ in dataset:
decoded_predictions_batch = []
if use_dropout:
for _ in range(forward_passes_per_image):
decoded_predictions_pass = ssd(inputs)
decoded_predictions_batch.append(decoded_predictions_pass)
else:
decoded_predictions_batch.append(ssd(inputs))
decoded_predictions.append(decoded_predictions_batch)
epoch_end_time = time.time()
per_epoch_time = epoch_end_time - epoch_start_time
# outputs for epoch
outputs = {
'per_epoch_time': per_epoch_time,
'decoded_predictions': decoded_predictions
}
return outputs
def train(dataset: tf.data.Dataset,