Added top k filtering for bayesian ssd
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
599f7d8957
commit
9b3d9a550b
|
@ -214,6 +214,10 @@ def predict(generator: callable,
|
||||||
confidence_threshold=confidence_threshold,
|
confidence_threshold=confidence_threshold,
|
||||||
),
|
),
|
||||||
apply_entropy_threshold_func=_apply_entropy_threshold,
|
apply_entropy_threshold_func=_apply_entropy_threshold,
|
||||||
|
apply_top_k_func=functools.partial(
|
||||||
|
_apply_top_k,
|
||||||
|
top_k=top_k
|
||||||
|
),
|
||||||
get_observations_func=_get_observations,
|
get_observations_func=_get_observations,
|
||||||
transform_func=functools.partial(
|
transform_func=functools.partial(
|
||||||
_transform_predictions,
|
_transform_predictions,
|
||||||
|
@ -300,7 +304,7 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||||
dropout_step: callable, vanilla_step: callable,
|
dropout_step: callable, vanilla_step: callable,
|
||||||
save_images: callable, decode_func: callable,
|
save_images: callable, decode_func: callable,
|
||||||
decode_func_dropout: callable, get_observations_func: callable,
|
decode_func_dropout: callable, get_observations_func: callable,
|
||||||
apply_entropy_threshold_func: callable,
|
apply_entropy_threshold_func: callable, apply_top_k_func: callable,
|
||||||
transform_func: callable, save_func: callable,
|
transform_func: callable, save_func: callable,
|
||||||
use_entropy_threshold: bool, entropy_threshold_min: float,
|
use_entropy_threshold: bool, entropy_threshold_min: float,
|
||||||
entropy_threshold_max: float) -> None:
|
entropy_threshold_max: float) -> None:
|
||||||
|
@ -329,6 +333,7 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
decoded_predictions = apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold)
|
decoded_predictions = apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold)
|
||||||
|
decoded_predictions = apply_top_k_func(decoded_predictions)
|
||||||
else:
|
else:
|
||||||
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
||||||
if not saved_images_decoding:
|
if not saved_images_decoding:
|
||||||
|
@ -421,6 +426,29 @@ def _apply_entropy_threshold(observations: Sequence[np.ndarray], entropy_thresho
|
||||||
return final_observations
|
return final_observations
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_top_k(detections: Sequence[np.ndarray], top_k: float) -> List[np.ndarray]:
|
||||||
|
final_detections = []
|
||||||
|
batch_size = len(detections)
|
||||||
|
data_type = np.dtype([('image_id', np.int32),
|
||||||
|
('confidence', 'f4'),
|
||||||
|
('entropy', 'f4'),
|
||||||
|
('xmin', 'f4'),
|
||||||
|
('ymin', 'f4'),
|
||||||
|
('xmax', 'f4'),
|
||||||
|
('ymax', 'f4')])
|
||||||
|
for i in range(batch_size):
|
||||||
|
image_detections = detections[i]
|
||||||
|
image_detections_structured = np.array(image_detections, dtype=data_type)
|
||||||
|
descending_indices = np.argsort(-image_detections_structured['confidence'])
|
||||||
|
image_detections_sorted = np.asarray(image_detections_structured[descending_indices])
|
||||||
|
top_k_indices = np.argpartition(image_detections_sorted[:, 1],
|
||||||
|
kth=image_detections_sorted.shape[0] - top_k,
|
||||||
|
axis=0)[image_detections_sorted.shape[0] - top_k:]
|
||||||
|
final_detections.append(image_detections_sorted[top_k_indices])
|
||||||
|
|
||||||
|
return final_detections
|
||||||
|
|
||||||
|
|
||||||
def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
|
def _transform_predictions(decoded_predictions: np.ndarray, inverse_transforms: Sequence[np.ndarray],
|
||||||
inverse_transform_func: callable) -> np.ndarray:
|
inverse_transform_func: callable) -> np.ndarray:
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue