From 639d04747db1aee9a5840124891dcf016be4d0e4 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Mon, 12 Aug 2019 19:26:12 +0200 Subject: [PATCH] Provide top_k to decode function Signed-off-by: Jim Martens --- src/twomartens/masterthesis/cli.py | 1 + src/twomartens/masterthesis/ssd.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index e2efd59..d982ba5 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -292,6 +292,7 @@ def _ssd_test(args: argparse.Namespace) -> None: entropy_threshold_max, confidence_threshold, iou_threshold, + top_k, output_path, coco_path, use_dropout, diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index bd78aff..4dc3af9 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -155,6 +155,7 @@ def predict(generator: callable, entropy_threshold_max: float, confidence_threshold: float, iou_threshold: float, + top_k: int, output_path: str, coco_path: str, use_dropout: bool, @@ -178,6 +179,7 @@ def predict(generator: callable, confidence_threshold: minimum confidence required for box to count as positive iou_threshold: all boxes with iou overlap larger than threshold to local maximum box will be suppressed + top_k: a maximum of top_k boxes remain after NMS output_path: the path in which the results should be saved coco_path: the path to the COCO data set use_dropout: if True, multiple forward passes and observations will be used @@ -203,7 +205,8 @@ def predict(generator: callable, decode_func=ssd_output_decoder.decode_detections, image_size=image_size, confidence_threshold=confidence_threshold, - iou_threshold=iou_threshold + iou_threshold=iou_threshold, + top_k=top_k ), transform_func=functools.partial( _transform_predictions, @@ -357,7 +360,8 @@ def _decode_predictions(predictions: np.ndarray, image_size: int, entropy_threshold: float, confidence_threshold: float, - iou_threshold: float) -> np.ndarray: + iou_threshold: float, + top_k: int) -> np.ndarray: return decode_func( y_pred=predictions, img_width=image_size, @@ -365,7 +369,8 @@ def _decode_predictions(predictions: np.ndarray, input_coords="corners", entropy_thresh=entropy_threshold, confidence_thresh=confidence_threshold, - iou_threshold=iou_threshold + iou_threshold=iou_threshold, + top_k=top_k )