Added calculation of open set error
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -96,7 +96,8 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
|
||||
iou_threshold: float = 0.5,
|
||||
border_pixels: str = "include",
|
||||
sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
|
||||
List[np.ndarray], List[np.ndarray]]:
|
||||
List[np.ndarray], List[np.ndarray],
|
||||
int]:
|
||||
"""
|
||||
Matches predictions to ground truth boxes.
|
||||
|
||||
@ -120,10 +121,11 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
|
||||
|
||||
Returns:
|
||||
true positives, false positives, cumulative true positives, and cumulative false positives for
|
||||
each class
|
||||
each class, open set error as defined by Miller et al
|
||||
"""
|
||||
true_positives = [[]] # The false positives for each class, sorted by descending confidence.
|
||||
false_positives = [[]] # The true positives for each class, sorted by descending confidence.
|
||||
open_set_error = 0
|
||||
cumulative_true_positives = [[]]
|
||||
cumulative_false_positives = [[]]
|
||||
|
||||
@ -179,6 +181,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
|
||||
# If the image doesn't contain any objects of this class,
|
||||
# the prediction becomes a false positive.
|
||||
false_pos[i] = 1
|
||||
open_set_error += 1
|
||||
continue
|
||||
|
||||
# Compute the IoU of this prediction with all ground truth boxes of the same class.
|
||||
@ -229,7 +232,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
|
||||
cumulative_true_positives.append(cumulative_true_pos)
|
||||
cumulative_false_positives.append(cumulative_false_pos)
|
||||
|
||||
return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives
|
||||
return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error
|
||||
|
||||
|
||||
def get_precision_recall(number_gt_per_class: np.ndarray,
|
||||
|
||||
@ -239,7 +239,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
||||
|
||||
# compute matches between predictions and ground truth
|
||||
true_positives, false_positives, \
|
||||
cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class,
|
||||
cum_true_positives, cum_false_positives, open_set_error = evaluate.match_predictions(predictions_per_class,
|
||||
labels,
|
||||
ssd.N_CLASSES)
|
||||
del labels
|
||||
|
||||
Reference in New Issue
Block a user