diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 2f481dd..74647e5 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -441,11 +441,14 @@ def _apply_top_k(detections: Sequence[np.ndarray], top_k: float) -> List[np.ndar image_detections_structured = np.core.records.fromarrays(image_detections.transpose(), dtype=data_type) descending_indices = np.argsort(-image_detections_structured['confidence']) - image_detections_sorted = np.asarray(image_detections_structured[descending_indices]) - top_k_indices = np.argpartition(image_detections_sorted[:, 1], - kth=image_detections_sorted.shape[0] - top_k, - axis=0)[image_detections_sorted.shape[0] - top_k:] - final_detections.append(image_detections_sorted[top_k_indices]) + image_detections_sorted = image_detections[descending_indices] + if image_detections_sorted.shape[0] > top_k: + top_k_indices = np.argpartition(image_detections_sorted[:, 1], + kth=image_detections_sorted.shape[0] - top_k, + axis=0)[image_detections_sorted.shape[0] - top_k:] + final_detections.append(image_detections_sorted[top_k_indices]) + else: + final_detections.append(image_detections_sorted) return final_detections