Provide top_k to decode function

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-08-12 19:26:12 +02:00
parent ef38ba29c1
commit 639d04747d
2 changed files with 9 additions and 3 deletions

View File

@ -292,6 +292,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
entropy_threshold_max, entropy_threshold_max,
confidence_threshold, confidence_threshold,
iou_threshold, iou_threshold,
top_k,
output_path, output_path,
coco_path, coco_path,
use_dropout, use_dropout,

View File

@ -155,6 +155,7 @@ def predict(generator: callable,
entropy_threshold_max: float, entropy_threshold_max: float,
confidence_threshold: float, confidence_threshold: float,
iou_threshold: float, iou_threshold: float,
top_k: int,
output_path: str, output_path: str,
coco_path: str, coco_path: str,
use_dropout: bool, use_dropout: bool,
@ -178,6 +179,7 @@ def predict(generator: callable,
confidence_threshold: minimum confidence required for box to count as positive confidence_threshold: minimum confidence required for box to count as positive
iou_threshold: all boxes with iou overlap larger than threshold to local maximum box iou_threshold: all boxes with iou overlap larger than threshold to local maximum box
will be suppressed will be suppressed
top_k: a maximum of top_k boxes remain after NMS
output_path: the path in which the results should be saved output_path: the path in which the results should be saved
coco_path: the path to the COCO data set coco_path: the path to the COCO data set
use_dropout: if True, multiple forward passes and observations will be used use_dropout: if True, multiple forward passes and observations will be used
@ -203,7 +205,8 @@ def predict(generator: callable,
decode_func=ssd_output_decoder.decode_detections, decode_func=ssd_output_decoder.decode_detections,
image_size=image_size, image_size=image_size,
confidence_threshold=confidence_threshold, confidence_threshold=confidence_threshold,
iou_threshold=iou_threshold iou_threshold=iou_threshold,
top_k=top_k
), ),
transform_func=functools.partial( transform_func=functools.partial(
_transform_predictions, _transform_predictions,
@ -357,7 +360,8 @@ def _decode_predictions(predictions: np.ndarray,
image_size: int, image_size: int,
entropy_threshold: float, entropy_threshold: float,
confidence_threshold: float, confidence_threshold: float,
iou_threshold: float) -> np.ndarray: iou_threshold: float,
top_k: int) -> np.ndarray:
return decode_func( return decode_func(
y_pred=predictions, y_pred=predictions,
img_width=image_size, img_width=image_size,
@ -365,7 +369,8 @@ def _decode_predictions(predictions: np.ndarray,
input_coords="corners", input_coords="corners",
entropy_thresh=entropy_threshold, entropy_thresh=entropy_threshold,
confidence_thresh=confidence_threshold, confidence_thresh=confidence_threshold,
iou_threshold=iou_threshold iou_threshold=iou_threshold,
top_k=top_k
) )