Changed predict function to conform to clean code standards
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -27,10 +27,11 @@ Functions:
|
|||||||
predict(...): runs trained SSD/DropoutSSD on a given data set
|
predict(...): runs trained SSD/DropoutSSD on a given data set
|
||||||
train(...): trains the SSD/DropoutSSD on a given data set
|
train(...): trains the SSD/DropoutSSD on a given data set
|
||||||
"""
|
"""
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from typing import List, Sequence, Tuple
|
from typing import List, Sequence, Tuple, Generator
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -142,12 +143,13 @@ def compile_model(model: tf.keras.models.Model, learning_rate: float, loss_func:
|
|||||||
|
|
||||||
|
|
||||||
def predict(generator: callable,
|
def predict(generator: callable,
|
||||||
|
model: tf.keras.models.Model,
|
||||||
steps_per_epoch: int,
|
steps_per_epoch: int,
|
||||||
ssd_model: tf.keras.models.Model,
|
|
||||||
use_dropout: bool,
|
|
||||||
forward_passes_per_image: int,
|
|
||||||
image_size: int,
|
image_size: int,
|
||||||
|
batch_size: int,
|
||||||
|
forward_passes_per_image: int,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
|
use_dropout: bool,
|
||||||
nr_digits: int) -> None:
|
nr_digits: int) -> None:
|
||||||
"""
|
"""
|
||||||
Run trained SSD on the given data set.
|
Run trained SSD on the given data set.
|
||||||
@ -156,67 +158,115 @@ def predict(generator: callable,
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
generator: generator of test data
|
generator: generator of test data
|
||||||
|
model: compiled and trained Keras model
|
||||||
steps_per_epoch: number of batches per epoch
|
steps_per_epoch: number of batches per epoch
|
||||||
ssd_model: compiled and trained Keras model
|
image_size: size of input images to model
|
||||||
use_dropout: if True, multiple forward passes and observations will be used
|
batch_size: number of items in every batch
|
||||||
forward_passes_per_image: specifies number of forward passes per image
|
forward_passes_per_image: specifies number of forward passes per image
|
||||||
used by DropoutSSD
|
used by DropoutSSD
|
||||||
image_size: size of input images to model
|
|
||||||
output_path: the path in which the results should be saved
|
output_path: the path in which the results should be saved
|
||||||
|
use_dropout: if True, multiple forward passes and observations will be used
|
||||||
nr_digits: number of digits needed to print largest batch number
|
nr_digits: number of digits needed to print largest batch number
|
||||||
"""
|
"""
|
||||||
# prepare filename
|
output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
|
||||||
filename = 'ssd_predictions'
|
|
||||||
label_filename = 'ssd_labels'
|
_predict_loop(generator, use_dropout, steps_per_epoch,
|
||||||
|
functools.partial(_predict_dropout_step,
|
||||||
|
model=model,
|
||||||
|
get_observations_func=_get_observations,
|
||||||
|
batch_size=batch_size,
|
||||||
|
forward_passes_per_image=forward_passes_per_image),
|
||||||
|
functools.partial(_predict_vanilla_step, model=model),
|
||||||
|
functools.partial(_transform_predictions,
|
||||||
|
decode_func=ssd_output_decoder.decode_detections_fast,
|
||||||
|
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms,
|
||||||
|
image_size=image_size),
|
||||||
|
functools.partial(_save_predictions,
|
||||||
|
output_file=output_file,
|
||||||
|
label_output_file=label_output_file,
|
||||||
|
nr_digits=nr_digits))
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, str]:
|
||||||
|
filename = "ssd_predictions"
|
||||||
|
label_filename = "ssd_labels"
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
filename = f"dropout-{filename}"
|
filename = f"dropout-{filename}"
|
||||||
|
|
||||||
output_file = os.path.join(output_path, filename)
|
output_file = os.path.join(output_path, filename)
|
||||||
label_output_file = os.path.join(output_path, label_filename)
|
label_output_file = os.path.join(output_path, label_filename)
|
||||||
image_size = (image_size, image_size)
|
|
||||||
|
return output_file, label_output_file
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||||
|
dropout_step: callable, vanilla_step: callable,
|
||||||
|
transform_func: callable, save_func: callable) -> None:
|
||||||
|
|
||||||
batch_counter = 0
|
batch_counter = 0
|
||||||
for x, filenames, inverse_transforms, original_labels in generator:
|
for inputs, filenames, inverse_transforms, original_labels in generator:
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
detections = None
|
predictions = dropout_step(inputs)
|
||||||
batch_size = None
|
else:
|
||||||
for _ in range(forward_passes_per_image):
|
predictions = vanilla_step(inputs)
|
||||||
predictions = ssd_model.predict_on_batch(x)
|
|
||||||
if batch_size is None:
|
transformed_predictions = transform_func(predictions, inverse_transforms)
|
||||||
batch_size = predictions.shape[0]
|
save_func(transformed_predictions, original_labels, filenames,
|
||||||
if detections is None:
|
batch_counter)
|
||||||
|
|
||||||
|
batch_counter += 1
|
||||||
|
|
||||||
|
if batch_counter == steps_per_epoch:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_dropout_step(inputs: np.ndarray, model: tf.keras.models.Model,
|
||||||
|
get_observations_func: callable,
|
||||||
|
batch_size: int, forward_passes_per_image: int) -> np.ndarray:
|
||||||
|
|
||||||
detections = [[] for _ in range(batch_size)]
|
detections = [[] for _ in range(batch_size)]
|
||||||
|
|
||||||
|
for _ in range(forward_passes_per_image):
|
||||||
|
predictions = model.predict_on_batch(inputs)
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
batch_item = predictions[i]
|
batch_item = predictions[i]
|
||||||
detections[i].extend(batch_item)
|
detections[i].extend(batch_item)
|
||||||
|
|
||||||
# do observation stuff
|
observations = np.asarray(get_observations_func(detections))
|
||||||
predictions = np.asarray(_get_observations(detections))
|
|
||||||
else:
|
|
||||||
predictions = np.asarray(ssd_model.predict_on_batch(x))
|
|
||||||
|
|
||||||
decoded_predictions_batch = ssd_output_decoder.decode_detections_fast(
|
return observations
|
||||||
|
|
||||||
|
|
||||||
|
def _predict_vanilla_step(inputs: np.ndarray, model: tf.keras.models.Model) -> np.ndarray:
|
||||||
|
return np.asarray(model.predict_on_batch(inputs))
|
||||||
|
|
||||||
|
|
||||||
|
def _transform_predictions(predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
|
||||||
|
decode_func: callable, inverse_transform_func: callable,
|
||||||
|
image_size: int) -> np.ndarray:
|
||||||
|
|
||||||
|
decoded_predictions = decode_func(
|
||||||
y_pred=predictions,
|
y_pred=predictions,
|
||||||
img_width=image_size[0],
|
img_width=image_size,
|
||||||
img_height=image_size[1],
|
img_height=image_size
|
||||||
)
|
|
||||||
transformed_predictions_batch = object_detection_2d_misc_utils.apply_inverse_transforms(
|
|
||||||
decoded_predictions_batch, inverse_transforms
|
|
||||||
)
|
)
|
||||||
|
transformed_predictions = inverse_transform_func(decoded_predictions, inverse_transforms)
|
||||||
|
|
||||||
# save prediction results to prevent memory issues
|
return transformed_predictions
|
||||||
counter_str = str(batch_counter).zfill(nr_digits)
|
|
||||||
|
|
||||||
|
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
|
||||||
|
output_file: str, label_output_file: str,
|
||||||
|
batch_nr: int, nr_digits: int) -> None:
|
||||||
|
|
||||||
|
counter_str = str(batch_nr).zfill(nr_digits)
|
||||||
filename = f"{output_file}-{counter_str}.bin"
|
filename = f"{output_file}-{counter_str}.bin"
|
||||||
label_filename = f"{label_output_file}-{counter_str}.bin"
|
label_filename = f"{label_output_file}-{counter_str}.bin"
|
||||||
|
|
||||||
with open(filename, 'wb') as file, open(label_filename, 'wb') as label_file:
|
with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
|
||||||
pickle.dump(transformed_predictions_batch, file)
|
pickle.dump(transformed_predictions, file)
|
||||||
pickle.dump({'labels': original_labels, 'filenames': filenames}, label_file)
|
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
|
||||||
|
|
||||||
batch_counter += 1
|
|
||||||
# we only do one epoch for prediction
|
|
||||||
if batch_counter == steps_per_epoch:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def train(train_generator: callable,
|
def train(train_generator: callable,
|
||||||
|
|||||||
Reference in New Issue
Block a user