Added prediction functionality to ssd module
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -31,6 +31,7 @@ Classes:
|
|||||||
``SSD``: wraps vanilla SSD 300 model
|
``SSD``: wraps vanilla SSD 300 model
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import time
|
import time
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -97,6 +98,100 @@ class DropoutSSD:
|
|||||||
|
|
||||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||||
return self._model(inputs)
|
return self._model(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def predict(dataset: tf.data.Dataset,
|
||||||
|
use_dropout: bool,
|
||||||
|
output_path: str,
|
||||||
|
weights_path: Optional[str] = None,
|
||||||
|
checkpoint_path: Optional[str] = None,
|
||||||
|
verbose: Optional[bool] = False,
|
||||||
|
forward_passes_per_image: Optional[int] = 42) -> None:
|
||||||
|
"""
|
||||||
|
Run trained SSD on the given data set.
|
||||||
|
|
||||||
|
Either the weights path or the checkpoint path must be given. This prevents
|
||||||
|
a scenario where an untrained network is used to predict.
|
||||||
|
|
||||||
|
The prediction results are saved to the output path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the testing data set
|
||||||
|
use_dropout: if True, DropoutSSD will be used
|
||||||
|
output_path: the path in which the results should be saved
|
||||||
|
weights_path: the path to the trained Keras weights (h5 file)
|
||||||
|
checkpoint_path: the path to the stored checkpoints (Tensorflow checkpoints)
|
||||||
|
verbose: if True, progress is printed to the standard output
|
||||||
|
forward_passes_per_image: specifies number of forward passes per image
|
||||||
|
used by DropoutSSD
|
||||||
|
"""
|
||||||
|
if weights_path is None and checkpoint_path is None:
|
||||||
|
raise ValueError("Either 'weights_path' or 'checkpoint_path' must be given.")
|
||||||
|
|
||||||
|
checkpointables = {}
|
||||||
|
if use_dropout:
|
||||||
|
checkpointables.update({
|
||||||
|
'ssd': DropoutSSD(mode='inference_fast', weights_path=weights_path)
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
checkpointables.update({
|
||||||
|
'ssd': SSD(mode='inference_fast', weights_path=weights_path)
|
||||||
|
})
|
||||||
|
|
||||||
|
if checkpoint_path is not None:
|
||||||
|
# checkpoint
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
checkpoint = tf.train.Checkpoint(**checkpointables)
|
||||||
|
checkpoint.restore(latest_checkpoint)
|
||||||
|
|
||||||
|
outputs = _predict_one_epoch(dataset, use_dropout, forward_passes_per_image, **checkpointables)
|
||||||
|
|
||||||
|
# save predictions
|
||||||
|
filename = 'ssd_predictions.bin'
|
||||||
|
if use_dropout:
|
||||||
|
filename = 'dropout-' + filename
|
||||||
|
|
||||||
|
with open(output_path + filename, 'wb') as file:
|
||||||
|
pickle.dump(outputs, file)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print((
|
||||||
|
f"predict time: {outputs['per_epoch_time']:.2f}, "
|
||||||
|
))
|
||||||
|
print("Prediction finished!... save outputs")
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_one_epoch(dataset: tf.data.Dataset,
|
||||||
|
use_dropout: bool,
|
||||||
|
forward_passes_per_image: int,
|
||||||
|
ssd: tf.keras.Model) -> Dict[str, float]:
|
||||||
|
|
||||||
|
with summary_ops_v2.always_record_summaries():
|
||||||
|
epoch_start_time = time.time()
|
||||||
|
decoded_predictions = []
|
||||||
|
|
||||||
|
# go through the data set
|
||||||
|
for inputs, _ in dataset:
|
||||||
|
decoded_predictions_batch = []
|
||||||
|
if use_dropout:
|
||||||
|
for _ in range(forward_passes_per_image):
|
||||||
|
decoded_predictions_pass = ssd(inputs)
|
||||||
|
decoded_predictions_batch.append(decoded_predictions_pass)
|
||||||
|
else:
|
||||||
|
decoded_predictions_batch.append(ssd(inputs))
|
||||||
|
|
||||||
|
decoded_predictions.append(decoded_predictions_batch)
|
||||||
|
|
||||||
|
epoch_end_time = time.time()
|
||||||
|
per_epoch_time = epoch_end_time - epoch_start_time
|
||||||
|
|
||||||
|
# outputs for epoch
|
||||||
|
outputs = {
|
||||||
|
'per_epoch_time': per_epoch_time,
|
||||||
|
'decoded_predictions': decoded_predictions
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def train(dataset: tf.data.Dataset,
|
def train(dataset: tf.data.Dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user