Changed code so that images are only saved on first batch
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -182,11 +182,14 @@ def predict(generator: callable,
|
|||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
coco_path=coco_path,
|
coco_path=coco_path,
|
||||||
image_size=image_size),
|
image_size=image_size),
|
||||||
|
decode_func=functools.partial(
|
||||||
|
_decode_predictions,
|
||||||
|
decode_func=ssd_output_decoder.decode_detections_fast,
|
||||||
|
image_size=image_size
|
||||||
|
),
|
||||||
transform_func=functools.partial(
|
transform_func=functools.partial(
|
||||||
_transform_predictions,
|
_transform_predictions,
|
||||||
decode_func=ssd_output_decoder.decode_detections_fast,
|
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms),
|
||||||
inverse_transform_func=object_detection_2d_misc_utils.apply_inverse_transforms,
|
|
||||||
image_size=image_size),
|
|
||||||
save_func=functools.partial(_save_predictions,
|
save_func=functools.partial(_save_predictions,
|
||||||
output_file=output_file,
|
output_file=output_file,
|
||||||
label_output_file=label_output_file,
|
label_output_file=label_output_file,
|
||||||
@ -264,7 +267,7 @@ def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, st
|
|||||||
|
|
||||||
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||||
dropout_step: callable, vanilla_step: callable,
|
dropout_step: callable, vanilla_step: callable,
|
||||||
save_images: callable,
|
save_images: callable, decode_func: callable,
|
||||||
transform_func: callable, save_func: callable) -> None:
|
transform_func: callable, save_func: callable) -> None:
|
||||||
|
|
||||||
batch_counter = 0
|
batch_counter = 0
|
||||||
@ -277,11 +280,12 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
|||||||
|
|
||||||
if not saved_images:
|
if not saved_images:
|
||||||
save_images(inputs, predictions, custom_string="after-prediction")
|
save_images(inputs, predictions, custom_string="after-prediction")
|
||||||
|
decoded_predictions = decode_func(predictions)
|
||||||
|
if not saved_images:
|
||||||
|
save_images(inputs, decoded_predictions, custom_string="after-decoding")
|
||||||
saved_images = True
|
saved_images = True
|
||||||
transformed_predictions = transform_func(predictions,
|
transformed_predictions = transform_func(decoded_predictions,
|
||||||
inverse_transforms,
|
inverse_transforms)
|
||||||
functools.partial(save_images,
|
|
||||||
inputs))
|
|
||||||
save_func(transformed_predictions, original_labels, filenames,
|
save_func(transformed_predictions, original_labels, filenames,
|
||||||
batch_nr=batch_counter)
|
batch_nr=batch_counter)
|
||||||
|
|
||||||
@ -313,21 +317,21 @@ def _predict_vanilla_step(inputs: np.ndarray, model: tf.keras.models.Model) -> n
|
|||||||
return np.asarray(model.predict_on_batch(inputs))
|
return np.asarray(model.predict_on_batch(inputs))
|
||||||
|
|
||||||
|
|
||||||
def _transform_predictions(predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
|
def _decode_predictions(predictions: np.ndarray,
|
||||||
save_images: callable,
|
decode_func: callable,
|
||||||
decode_func: callable, inverse_transform_func: callable,
|
image_size: int) -> np.ndarray:
|
||||||
image_size: int) -> np.ndarray:
|
return decode_func(
|
||||||
|
|
||||||
decoded_predictions = decode_func(
|
|
||||||
y_pred=predictions,
|
y_pred=predictions,
|
||||||
img_width=image_size,
|
img_width=image_size,
|
||||||
img_height=image_size,
|
img_height=image_size,
|
||||||
input_coords="corners"
|
input_coords="corners"
|
||||||
)
|
)
|
||||||
save_images(decoded_predictions, custom_string="after-decoding")
|
|
||||||
transformed_predictions = inverse_transform_func(decoded_predictions, inverse_transforms)
|
|
||||||
|
def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
|
||||||
|
inverse_transform_func: callable) -> np.ndarray:
|
||||||
|
|
||||||
return transformed_predictions
|
return inverse_transform_func(decoded_predictions, inverse_transforms)
|
||||||
|
|
||||||
|
|
||||||
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
|
def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.ndarray, filenames: Sequence[str],
|
||||||
|
|||||||
Reference in New Issue
Block a user