From 13252b37b29e494d61a44ea27e755321c2e7d68a Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Mon, 15 Jul 2019 11:59:35 +0200 Subject: [PATCH] Changed code so that images are only saved on first batch Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 38 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index e391976..1d9bcde 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -182,11 +182,14 @@ def predict(generator: callable, output_path=output_path, coco_path=coco_path, image_size=image_size), + decode_func=functools.partial( + _decode_predictions, + decode_func=ssd_output_decoder.decode_detections_fast, + image_size=image_size + ), transform_func=functools.partial( _transform_predictions, - decode_func=ssd_output_decoder.decode_detections_fast, - inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms, - image_size=image_size), + inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms), save_func=functools.partial(_save_predictions, output_file=output_file, label_output_file=label_output_file, @@ -264,7 +267,7 @@ def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, st def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, dropout_step: callable, vanilla_step: callable, - save_images: callable, + save_images: callable, decode_func: callable, transform_func: callable, save_func: callable) -> None: batch_counter = 0 @@ -277,11 +280,12 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int, if not saved_images: save_images(inputs, predictions, custom_string="after-prediction") + decoded_predictions = decode_func(predictions) + if not saved_images: + save_images(inputs, decoded_predictions, custom_string="after-decoding") saved_images = True - transformed_predictions = transform_func(predictions, - inverse_transforms, - functools.partial(save_images, - inputs)) + transformed_predictions = transform_func(decoded_predictions, + inverse_transforms) save_func(transformed_predictions, original_labels, filenames, batch_nr=batch_counter) @@ -313,21 +317,21 @@ def _predict_vanilla_step(inputs: np.ndarray, model: tf.keras.models.Model) -> n return np.asarray(model.predict_on_batch(inputs)) -def _transform_predictions(predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray], - save_images: callable, - decode_func: callable, inverse_transform_func: callable, - image_size: int) -> np.ndarray: - - decoded_predictions = decode_func( +def _decode_predictions(predictions: np.ndarray, + decode_func: callable, + image_size: int) -> np.ndarray: + return decode_func( y_pred=predictions, img_width=image_size, img_height=image_size, input_coords="corners" ) - save_images(decoded_predictions, custom_string="after-decoding") - transformed_predictions = inverse_transform_func(decoded_predictions, inverse_transforms) + + +def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray], + inverse_transform_func: callable) -> np.ndarray: - return transformed_predictions + return inverse_transform_func(decoded_predictions, inverse_transforms) def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],