Changed predict function to conform to clean code standards

Signed-off-by: Jim Martens <github@2martens.de>
2019-07-11 14:52:59 +02:00
parent 718467e43c
commit fbcf2c261c


@@ -27,10 +27,11 @@ Functions:
     predict(...): runs trained SSD/DropoutSSD on a given data set
     train(...): trains the SSD/DropoutSSD on a given data set
 """
+import functools
 import os
 import pickle
-from typing import List, Sequence, Tuple
+from typing import List, Sequence, Tuple, Generator
 from typing import Optional
 import numpy as np
@@ -142,12 +143,13 @@ def compile_model(model: tf.keras.models.Model, learning_rate: float, loss_func:
 def predict(generator: callable,
-            ssd_model: tf.keras.models.Model,
+            model: tf.keras.models.Model,
+            steps_per_epoch: int,
+            use_dropout: bool,
+            forward_passes_per_image: int,
             image_size: int,
             batch_size: int,
-            forward_passes_per_image: int,
             output_path: str,
-            use_dropout: bool,
             nr_digits: int) -> None:
     """
     Run trained SSD on the given data set.
@@ -156,69 +158,117 @@ def predict(generator: callable,
     Args:
         generator: generator of test data
-        ssd_model: compiled and trained Keras model
+        model: compiled and trained Keras model
+        steps_per_epoch: number of batches per epoch
+        use_dropout: if True, multiple forward passes and observations will be used
+        image_size: size of input images to model
         batch_size: number of items in every batch
        forward_passes_per_image: specifies number of forward passes per image
            used by DropoutSSD
-        image_size: size of input images to model
         output_path: the path in which the results should be saved
-        use_dropout: if True, multiple forward passes and observations will be used
         nr_digits: number of digits needed to print largest batch number
     """
-    # prepare filename
-    filename = 'ssd_predictions'
-    label_filename = 'ssd_labels'
+    output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
+    _predict_loop(generator, use_dropout, steps_per_epoch,
+                  functools.partial(_predict_dropout_step,
+                                    model=model,
+                                    get_observations_func=_get_observations,
+                                    batch_size=batch_size,
+                                    forward_passes_per_image=forward_passes_per_image),
+                  functools.partial(_predict_vanilla_step, model=model),
+                  functools.partial(_transform_predictions,
+                                    decode_func=ssd_output_decoder.decode_detections_fast,
+                                    inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms,
+                                    image_size=image_size),
+                  functools.partial(_save_predictions,
+                                    output_file=output_file,
+                                    label_output_file=label_output_file,
+                                    nr_digits=nr_digits))
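
The refactored predict now only wires up its collaborators: functools.partial pre-binds every configuration argument, so _predict_loop can invoke each step with nothing but the per-batch data. A minimal sketch of that pattern (standalone illustration; run_passes, model_stub and the argument values are hypothetical, not from this repo):

    import functools

    def run_passes(inputs, model, forward_passes_per_image):
        # stand-in for a real prediction step
        return [model(inputs) for _ in range(forward_passes_per_image)]

    def model_stub(inputs):
        return inputs  # stands in for a Keras model

    bound_step = functools.partial(run_passes, model=model_stub,
                                   forward_passes_per_image=3)
    outputs = bound_step([1, 2, 3])  # call sites now pass only per-batch data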
+
+
+def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, str]:
+    filename = "ssd_predictions"
+    label_filename = "ssd_labels"
     if use_dropout:
         filename = f"dropout-{filename}"
     output_file = os.path.join(output_path, filename)
     label_output_file = os.path.join(output_path, label_filename)
-    image_size = (image_size, image_size)
+    return output_file, label_output_file
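
For illustration only (not part of the commit), the path helper can be exercised on its own; on a POSIX system with output_path "results" it behaves as follows:

    output_file, label_output_file = _predict_prepare_paths("results", use_dropout=True)
    # output_file       -> "results/dropout-ssd_predictions"
    # label_output_file -> "results/ssd_labels"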
+
+
+def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
+                  dropout_step: callable, vanilla_step: callable,
+                  transform_func: callable, save_func: callable) -> None:
     batch_counter = 0
-    for x, filenames, inverse_transforms, original_labels in generator:
+    for inputs, filenames, inverse_transforms, original_labels in generator:
         if use_dropout:
-            detections = None
-            batch_size = None
-            for _ in range(forward_passes_per_image):
-                predictions = ssd_model.predict_on_batch(x)
-                if batch_size is None:
-                    batch_size = predictions.shape[0]
-                if detections is None:
-                    detections = [[] for _ in range(batch_size)]
-                for i in range(batch_size):
-                    batch_item = predictions[i]
-                    detections[i].extend(batch_item)
-            # do observation stuff
-            predictions = np.asarray(_get_observations(detections))
+            predictions = dropout_step(inputs)
         else:
-            predictions = np.asarray(ssd_model.predict_on_batch(x))
+            predictions = vanilla_step(inputs)
-        decoded_predictions_batch = ssd_output_decoder.decode_detections_fast(
-            y_pred=predictions,
-            img_width=image_size[0],
-            img_height=image_size[1],
-        )
-        transformed_predictions_batch = object_detection_2d_misc_utils.apply_inverse_transforms(
-            decoded_predictions_batch, inverse_transforms
-        )
-        # save prediction results to prevent memory issues
-        counter_str = str(batch_counter).zfill(nr_digits)
-        filename = f"{output_file}-{counter_str}.bin"
-        label_filename = f"{label_output_file}-{counter_str}.bin"
-        with open(filename, 'wb') as file, open(label_filename, 'wb') as label_file:
-            pickle.dump(transformed_predictions_batch, file)
-            pickle.dump({'labels': original_labels, 'filenames': filenames}, label_file)
+        transformed_predictions = transform_func(predictions, inverse_transforms)
+        save_func(transformed_predictions, original_labels, filenames,
+                  batch_counter)
         batch_counter += 1
+        # we only do one epoch for prediction
+        if batch_counter == steps_per_epoch:
+            break
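
The loop assumes the generator yields a four-tuple of (inputs, filenames, inverse_transforms, original_labels) per batch. A minimal dummy generator satisfying that contract could look like this (shapes and element types are assumptions for illustration; the diff does not specify them):

    import numpy as np

    def dummy_generator():
        """Yields batches shaped like the tuples _predict_loop unpacks."""
        while True:
            inputs = np.zeros((8, 300, 300, 3), dtype=np.float32)  # assumed image batch
            filenames = [f"image_{i}.png" for i in range(8)]       # hypothetical names
            inverse_transforms = [None] * 8                        # placeholder transforms
            original_labels = [np.zeros((0, 5))] * 8               # placeholder labels
            yield inputs, filenames, inverse_transforms, original_labels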
+
+
+def _predict_dropout_step(inputs: np.ndarray, model: tf.keras.models.Model,
+                          get_observations_func: callable,
+                          batch_size: int, forward_passes_per_image: int) -> np.ndarray:
+    detections = [[] for _ in range(batch_size)]
+    for _ in range(forward_passes_per_image):
+        predictions = model.predict_on_batch(inputs)
+        for i in range(batch_size):
+            batch_item = predictions[i]
+            detections[i].extend(batch_item)
+    observations = np.asarray(get_observations_func(detections))
+    return observations
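
_predict_dropout_step is the Monte Carlo part of MC dropout: the same batch is pushed through the network forward_passes_per_image times and the per-image detections from all passes are pooled before _get_observations (not shown in this diff) condenses them. For the passes to differ, dropout must stay active at inference time; one common way to do that in Keras (a general sketch, not this repo's DropoutSSD) is to call the Dropout layer with training=True:

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(300, 300, 3))        # input shape is an assumption
    x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs)
    x = tf.keras.layers.Dropout(0.5)(x, training=True)  # dropout stays on in predict()
    outputs = tf.keras.layers.Conv2D(4, 3)(x)
    mc_model = tf.keras.Model(inputs, outputs)          # hypothetical stand-in model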
+
+
+def _predict_vanilla_step(inputs: np.ndarray, model: tf.keras.models.Model) -> np.ndarray:
+    return np.asarray(model.predict_on_batch(inputs))
+
+
+def _transform_predictions(predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
+                           decode_func: callable, inverse_transform_func: callable,
+                           image_size: int) -> np.ndarray:
+    decoded_predictions = decode_func(
+        y_pred=predictions,
+        img_width=image_size,
+        img_height=image_size
+    )
+    transformed_predictions = inverse_transform_func(decoded_predictions, inverse_transforms)
+    return transformed_predictions
+
+
+def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
+                      output_file: str, label_output_file: str,
+                      batch_nr: int, nr_digits: int) -> None:
+    counter_str = str(batch_nr).zfill(nr_digits)
+    filename = f"{output_file}-{counter_str}.bin"
+    label_filename = f"{label_output_file}-{counter_str}.bin"
+    with open(filename, "wb") as file, open(label_filename, "wb") as label_file:
+        pickle.dump(transformed_predictions, file)
+        pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
 def train(train_generator: callable,
           steps_per_epoch_train: int,
           val_generator: callable,