Provide top_k to decode function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -292,6 +292,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
entropy_threshold_max,
|
entropy_threshold_max,
|
||||||
confidence_threshold,
|
confidence_threshold,
|
||||||
iou_threshold,
|
iou_threshold,
|
||||||
|
top_k,
|
||||||
output_path,
|
output_path,
|
||||||
coco_path,
|
coco_path,
|
||||||
use_dropout,
|
use_dropout,
|
||||||
|
|||||||
@ -155,6 +155,7 @@ def predict(generator: callable,
|
|||||||
entropy_threshold_max: float,
|
entropy_threshold_max: float,
|
||||||
confidence_threshold: float,
|
confidence_threshold: float,
|
||||||
iou_threshold: float,
|
iou_threshold: float,
|
||||||
|
top_k: int,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
coco_path: str,
|
coco_path: str,
|
||||||
use_dropout: bool,
|
use_dropout: bool,
|
||||||
@ -178,6 +179,7 @@ def predict(generator: callable,
|
|||||||
confidence_threshold: minimum confidence required for box to count as positive
|
confidence_threshold: minimum confidence required for box to count as positive
|
||||||
iou_threshold: all boxes with iou overlap larger than threshold to local maximum box
|
iou_threshold: all boxes with iou overlap larger than threshold to local maximum box
|
||||||
will be suppressed
|
will be suppressed
|
||||||
|
top_k: a maximum of top_k boxes remain after NMS
|
||||||
output_path: the path in which the results should be saved
|
output_path: the path in which the results should be saved
|
||||||
coco_path: the path to the COCO data set
|
coco_path: the path to the COCO data set
|
||||||
use_dropout: if True, multiple forward passes and observations will be used
|
use_dropout: if True, multiple forward passes and observations will be used
|
||||||
@ -203,7 +205,8 @@ def predict(generator: callable,
|
|||||||
decode_func=ssd_output_decoder.decode_detections,
|
decode_func=ssd_output_decoder.decode_detections,
|
||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
confidence_threshold=confidence_threshold,
|
confidence_threshold=confidence_threshold,
|
||||||
iou_threshold=iou_threshold
|
iou_threshold=iou_threshold,
|
||||||
|
top_k=top_k
|
||||||
),
|
),
|
||||||
transform_func=functools.partial(
|
transform_func=functools.partial(
|
||||||
_transform_predictions,
|
_transform_predictions,
|
||||||
@ -357,7 +360,8 @@ def _decode_predictions(predictions: np.ndarray,
|
|||||||
image_size: int,
|
image_size: int,
|
||||||
entropy_threshold: float,
|
entropy_threshold: float,
|
||||||
confidence_threshold: float,
|
confidence_threshold: float,
|
||||||
iou_threshold: float) -> np.ndarray:
|
iou_threshold: float,
|
||||||
|
top_k: int) -> np.ndarray:
|
||||||
return decode_func(
|
return decode_func(
|
||||||
y_pred=predictions,
|
y_pred=predictions,
|
||||||
img_width=image_size,
|
img_width=image_size,
|
||||||
@ -365,7 +369,8 @@ def _decode_predictions(predictions: np.ndarray,
|
|||||||
input_coords="corners",
|
input_coords="corners",
|
||||||
entropy_thresh=entropy_threshold,
|
entropy_thresh=entropy_threshold,
|
||||||
confidence_thresh=confidence_threshold,
|
confidence_thresh=confidence_threshold,
|
||||||
iou_threshold=iou_threshold
|
iou_threshold=iou_threshold,
|
||||||
|
top_k=top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user