From e384534fb912cd9e5f918727572141830b10d328 Mon Sep 17 00:00:00 2001
From: Jim Martens
Date: Tue, 21 May 2019 11:13:04 +0200
Subject: [PATCH] Added calculation of open set error

Signed-off-by: Jim Martens
---
 src/twomartens/masterthesis/evaluate.py | 9 ++++++---
 src/twomartens/masterthesis/main.py     | 6 +++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/twomartens/masterthesis/evaluate.py b/src/twomartens/masterthesis/evaluate.py
index 5016a9c..b0b93f2 100644
--- a/src/twomartens/masterthesis/evaluate.py
+++ b/src/twomartens/masterthesis/evaluate.py
@@ -96,7 +96,8 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
                       iou_threshold: float = 0.5,
                       border_pixels: str = "include",
                       sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
-                                                                     List[np.ndarray], List[np.ndarray]]:
+                                                                     List[np.ndarray], List[np.ndarray],
+                                                                     int]:
     """
     Matches predictions to ground truth boxes.
     
@@ -120,10 +121,11 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
     Returns:
         true positives, false positives, cumulative true positives, and cumulative false positives for
-        each class
+        each class, open set error as defined by Miller et al
     """
     
     true_positives = [[]]  # The false positives for each class, sorted by descending confidence.
     false_positives = [[]]  # The true positives for each class, sorted by descending confidence.
+    open_set_error = 0
     cumulative_true_positives = [[]]
     cumulative_false_positives = [[]]
@@ -179,6 +181,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
                 # If the image doesn't contain any objects of this class,
                 # the prediction becomes a false positive.
                 false_pos[i] = 1
+                open_set_error += 1
                 continue
             
             # Compute the IoU of this prediction with all ground truth boxes of the same class.
@@ -229,7 +232,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
         cumulative_true_positives.append(cumulative_true_pos)
         cumulative_false_positives.append(cumulative_false_pos)
     
-    return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives
+    return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error
 
 
 def get_precision_recall(number_gt_per_class: np.ndarray,
diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py
index 5d1dd68..886e470 100644
--- a/src/twomartens/masterthesis/main.py
+++ b/src/twomartens/masterthesis/main.py
@@ -239,9 +239,9 @@ def _ssd_test(args: argparse.Namespace) -> None:
     
     # compute matches between predictions and ground truth
     true_positives, false_positives, \
-        cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class,
-                                                                             labels,
-                                                                             ssd.N_CLASSES)
+        cum_true_positives, cum_false_positives, open_set_error = evaluate.match_predictions(predictions_per_class,
+                                                                                             labels,
+                                                                                             ssd.N_CLASSES)
     del labels
     cum_precisions, cum_recalls = evaluate.get_precision_recall(number_gt_per_class,
                                                                 cum_true_positives,