diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 550c88a..738e7c7 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -424,7 +424,8 @@ def _apply_entropy_filtering(observations: Sequence[np.ndarray], entropy_threshold: float, confidence_threshold: float, iou_threshold: float, - nr_classes: int) -> List[np.ndarray]: + nr_classes: int, + use_nms: bool = True) -> List[np.ndarray]: final_observations = [] batch_size = len(observations) for i in range(batch_size): @@ -438,10 +439,13 @@ def _apply_entropy_filtering(observations: Sequence[np.ndarray], single_class = filtered_image_observations[:, [class_id, -5, -4, -3, -2]] threshold_met = single_class[single_class[:, 0] > confidence_threshold] if threshold_met.shape[0] > 0: - maxima = ssd_output_decoder._greedy_nms(threshold_met, iou_threshold=iou_threshold) - maxima_output = np.zeros((maxima.shape[0], maxima.shape[1] + 1)) + if use_nms: + maxima = ssd_output_decoder._greedy_nms(threshold_met, iou_threshold=iou_threshold) + maxima_output = np.zeros((maxima.shape[0], maxima.shape[1] + 1)) + else: + maxima_output = np.zeros((threshold_met.shape[0], threshold_met.shape[1] + 1)) maxima_output[:, 0] = class_id - maxima_output[:, 1:] = maxima + maxima_output[:, 1:] = maxima if use_nms else threshold_met final_image_observations.append(maxima_output) if final_image_observations: final_image_observations = np.concatenate(final_image_observations, axis=0)