Changed filtering to occur per-class
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
aaaea5f942
commit
a41f607ca9
|
@ -274,7 +274,8 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
||||||
paths.output_path,
|
paths.output_path,
|
||||||
conf_obj.paths.coco,
|
conf_obj.paths.coco,
|
||||||
use_dropout,
|
use_dropout,
|
||||||
nr_digits)
|
nr_digits,
|
||||||
|
conf_obj.parameters.nr_classes)
|
||||||
|
|
||||||
|
|
||||||
def _ssd_evaluate(args: argparse.Namespace) -> None:
|
def _ssd_evaluate(args: argparse.Namespace) -> None:
|
||||||
|
|
|
@ -159,7 +159,8 @@ def predict(generator: callable,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
coco_path: str,
|
coco_path: str,
|
||||||
use_dropout: bool,
|
use_dropout: bool,
|
||||||
nr_digits: int) -> None:
|
nr_digits: int,
|
||||||
|
nr_classes: int) -> None:
|
||||||
"""
|
"""
|
||||||
Run trained SSD on the given data set.
|
Run trained SSD on the given data set.
|
||||||
|
|
||||||
|
@ -184,6 +185,7 @@ def predict(generator: callable,
|
||||||
coco_path: the path to the COCO data set
|
coco_path: the path to the COCO data set
|
||||||
use_dropout: if True, multiple forward passes and observations will be used
|
use_dropout: if True, multiple forward passes and observations will be used
|
||||||
nr_digits: number of digits needed to print largest batch number
|
nr_digits: number of digits needed to print largest batch number
|
||||||
|
nr_classes: number of classes
|
||||||
"""
|
"""
|
||||||
output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
|
output_file, label_output_file = _predict_prepare_paths(output_path, use_dropout)
|
||||||
|
|
||||||
|
@ -213,7 +215,11 @@ def predict(generator: callable,
|
||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
confidence_threshold=confidence_threshold,
|
confidence_threshold=confidence_threshold,
|
||||||
),
|
),
|
||||||
apply_entropy_threshold_func=_apply_entropy_threshold,
|
apply_entropy_threshold_func=functools.partial(
|
||||||
|
_apply_entropy_filtering,
|
||||||
|
confidence_threshold=confidence_threshold,
|
||||||
|
nr_classes=nr_classes
|
||||||
|
),
|
||||||
apply_top_k_func=functools.partial(
|
apply_top_k_func=functools.partial(
|
||||||
_apply_top_k,
|
_apply_top_k,
|
||||||
top_k=top_k
|
top_k=top_k
|
||||||
|
@ -332,7 +338,8 @@ def _predict_loop(generator: Generator, use_dropout: bool, steps_per_epoch: int,
|
||||||
observations = get_observations_func(decoded_predictions)
|
observations = get_observations_func(decoded_predictions)
|
||||||
for entropy_threshold in entropy_thresholds:
|
for entropy_threshold in entropy_thresholds:
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
decoded_predictions = apply_entropy_threshold_func(observations, entropy_threshold=entropy_threshold)
|
decoded_predictions = apply_entropy_threshold_func(observations,
|
||||||
|
entropy_threshold=entropy_threshold)
|
||||||
decoded_predictions = apply_top_k_func(decoded_predictions)
|
decoded_predictions = apply_top_k_func(decoded_predictions)
|
||||||
else:
|
else:
|
||||||
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
decoded_predictions = decode_func(predictions, entropy_threshold=entropy_threshold)
|
||||||
|
@ -408,7 +415,10 @@ def _decode_predictions_dropout(predictions: np.ndarray,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_entropy_threshold(observations: Sequence[np.ndarray], entropy_threshold: float) -> List[np.ndarray]:
|
def _apply_entropy_filtering(observations: Sequence[np.ndarray],
|
||||||
|
entropy_threshold: float,
|
||||||
|
confidence_threshold: float,
|
||||||
|
nr_classes: int) -> List[np.ndarray]:
|
||||||
final_observations = []
|
final_observations = []
|
||||||
batch_size = len(observations)
|
batch_size = len(observations)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
@ -417,10 +427,19 @@ def _apply_entropy_threshold(observations: Sequence[np.ndarray], entropy_thresho
|
||||||
continue
|
continue
|
||||||
|
|
||||||
filtered_image_observations = observations[i][observations[i][:, -1] < entropy_threshold]
|
filtered_image_observations = observations[i][observations[i][:, -1] < entropy_threshold]
|
||||||
final_image_observations = np.copy(filtered_image_observations[:, -8:-1])
|
final_image_observations = []
|
||||||
final_image_observations[:, 0] = np.argmax(filtered_image_observations[:, :-5], axis=-1)
|
for class_id in range(1, nr_classes):
|
||||||
final_image_observations[:, 1] = np.amax(filtered_image_observations[:, :-5], axis=-1)
|
single_class = filtered_image_observations[:, [class_id, -1, -5, -4, -3, -2]]
|
||||||
final_image_observations[:, 2] = filtered_image_observations[:, -1]
|
threshold_met = single_class[single_class[:, 0] > confidence_threshold]
|
||||||
|
if threshold_met.shape[0] > 0:
|
||||||
|
output = np.zeros((single_class.shape[0], single_class.shape[1] + 1))
|
||||||
|
output[:, 0] = class_id
|
||||||
|
output[:, 1:] = single_class
|
||||||
|
final_image_observations.append(output)
|
||||||
|
if final_image_observations:
|
||||||
|
final_image_observations = np.concatenate(final_image_observations, axis=0)
|
||||||
|
else:
|
||||||
|
final_image_observations = np.array(final_image_observations)
|
||||||
final_observations.append(final_image_observations)
|
final_observations.append(final_image_observations)
|
||||||
|
|
||||||
return final_observations
|
return final_observations
|
||||||
|
|
Loading…
Reference in New Issue