Added calculation of open set error

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-05-21 11:13:04 +02:00
parent 160346b4cb
commit e384534fb9
2 changed files with 9 additions and 6 deletions

View File

@ -96,7 +96,8 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
iou_threshold: float = 0.5, iou_threshold: float = 0.5,
border_pixels: str = "include", border_pixels: str = "include",
sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray], sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
List[np.ndarray], List[np.ndarray]]: List[np.ndarray], List[np.ndarray],
int]:
""" """
Matches predictions to ground truth boxes. Matches predictions to ground truth boxes.
@ -120,10 +121,11 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
Returns: Returns:
true positives, false positives, cumulative true positives, and cumulative false positives for true positives, false positives, cumulative true positives, and cumulative false positives for
each class each class, open set error as defined by Miller et al
""" """
true_positives = [[]] # The false positives for each class, sorted by descending confidence. true_positives = [[]] # The false positives for each class, sorted by descending confidence.
false_positives = [[]] # The true positives for each class, sorted by descending confidence. false_positives = [[]] # The true positives for each class, sorted by descending confidence.
open_set_error = 0
cumulative_true_positives = [[]] cumulative_true_positives = [[]]
cumulative_false_positives = [[]] cumulative_false_positives = [[]]
@ -179,6 +181,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
# If the image doesn't contain any objects of this class, # If the image doesn't contain any objects of this class,
# the prediction becomes a false positive. # the prediction becomes a false positive.
false_pos[i] = 1 false_pos[i] = 1
open_set_error += 1
continue continue
# Compute the IoU of this prediction with all ground truth boxes of the same class. # Compute the IoU of this prediction with all ground truth boxes of the same class.
@ -229,7 +232,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
cumulative_true_positives.append(cumulative_true_pos) cumulative_true_positives.append(cumulative_true_pos)
cumulative_false_positives.append(cumulative_false_pos) cumulative_false_positives.append(cumulative_false_pos)
return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error
def get_precision_recall(number_gt_per_class: np.ndarray, def get_precision_recall(number_gt_per_class: np.ndarray,

View File

@ -239,7 +239,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
# compute matches between predictions and ground truth # compute matches between predictions and ground truth
true_positives, false_positives, \ true_positives, false_positives, \
cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class, cum_true_positives, cum_false_positives, open_set_error = evaluate.match_predictions(predictions_per_class,
labels, labels,
ssd.N_CLASSES) ssd.N_CLASSES)
del labels del labels