# Call-site wiring introduced together with this helper (context only):
#   predict(...):       apply_top_k_func=functools.partial(_apply_top_k, top_k=top_k)
#   _predict_loop(...): decoded_predictions = apply_top_k_func(decoded_predictions)
#                       executed directly after apply_entropy_threshold_func(...)
#                       in the ``use_dropout`` branch.
def _apply_top_k(detections: Sequence[np.ndarray], top_k: float) -> List[np.ndarray]:
    """
    Keeps only the ``top_k`` most confident detections of every image.

    Each batch item is expected to be a 2-D array with one detection per row
    and the confidence in column 1 (the layout this helper has always indexed
    with ``[:, 1]``). NOTE(review): the remaining columns are presumably
    (id, entropy, xmin, ymin, xmax, ymax) - confirm against the producer.

    Args:
        detections: one array of detections per image in the batch
        top_k: maximum number of detections to keep per image; non-integer
            values are truncated towards zero

    Returns:
        list with one array per image, rows ordered by descending confidence
        and limited to at most ``top_k`` rows; images with fewer detections
        keep all of them, empty inputs are passed through unchanged

    Raises:
        ValueError: if ``top_k`` is negative
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative, got {}".format(top_k))
    # Slice bounds must be integers; the value may arrive as a float.
    keep = int(top_k)

    final_detections = []
    for image_detections in detections:
        rows = np.asarray(image_detections)
        if rows.size == 0:
            # Nothing to rank; also avoids indexing column 1 of an array
            # that may be one-dimensional when empty.
            final_detections.append(rows)
            continue

        # A stable sort keeps the original relative order of equally
        # confident detections, which makes the result deterministic.
        descending_indices = np.argsort(-rows[:, 1], kind="stable")
        # Slicing past the end is harmless, so images with fewer than
        # ``keep`` detections simply retain every row.
        final_detections.append(rows[descending_indices[:keep]])

    return final_detections