diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 3303b85..566ec3e 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -270,6 +270,9 @@ def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[n for i in range(batch_size): print(f"{i}th batch element") detections_image = np.asarray(detections[i]) + class_ids = np.argmax(detections_image[:, :-12], + axis=-1) + detections_image = detections_image[class_ids != 0] print(detections_image.shape) overlaps = bounding_box_utils.iou(detections_image[:, -12:-8], detections_image[:, -12:-8],