diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 2061dcc..c6a1d72 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -353,9 +353,9 @@ def _ssd_evaluate(args: argparse.Namespace) -> None: true_positives, false_positives, \ cum_true_positives, cum_false_positives, \ - open_set_error = evaluate.match_predictions(predictions_per_class, labels, - bounding_box_utils.iou, - nr_classes, iou_threshold) + open_set_error, cumulative_open_set_error = evaluate.match_predictions(predictions_per_class, labels, + bounding_box_utils.iou, + nr_classes, iou_threshold) cum_precisions, cum_recalls = evaluate.get_precision_recall(number_gt_per_class, cum_true_positives, @@ -375,7 +375,8 @@ def _ssd_evaluate(args: argparse.Namespace) -> None: f1_scores, average_precisions, mean_average_precision, - open_set_error) + open_set_error, + cumulative_open_set_error) _pickle(result_file, results) @@ -829,7 +830,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray], f1_scores: Sequence[np.ndarray], average_precisions: Sequence[float], mean_average_precision: float, - open_set_error: int + open_set_error: np.ndarray, + cumulative_open_set_error: np.ndarray ) -> Dict[str, Union[np.ndarray, float, int]]: results = { "true_positives": true_positives, @@ -841,7 +843,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray], "f1_scores": f1_scores, "mean_average_precisions": average_precisions, "mean_average_precision": mean_average_precision, - "open_set_error": open_set_error + "open_set_error": open_set_error, + "cumulative_open_set_error": cumulative_open_set_error } return results diff --git a/src/twomartens/masterthesis/evaluate.py b/src/twomartens/masterthesis/evaluate.py index 6b45b7c..58a7823 100644 --- a/src/twomartens/masterthesis/evaluate.py +++ b/src/twomartens/masterthesis/evaluate.py @@ -99,7 +99,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in border_pixels: str = "include", sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray], - int]: + np.ndarray, np.ndarray]: """ Matches predictions to ground truth boxes. @@ -124,13 +124,20 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in Returns: true positives, false positives, cumulative true positives, and cumulative false positives for - each class, open set error as defined by Miller et al + each class, open set error as defined by Miller et al, cumulative open set error """ true_positives = [[]] # The false positives for each class, sorted by descending confidence. false_positives = [[]] # The true positives for each class, sorted by descending confidence. - open_set_error = 0 cumulative_true_positives = [[]] cumulative_false_positives = [[]] + most_predictions = -1 + + for class_id in range(1, nr_classes + 1): + nr_predictions = len(predictions[class_id]) + if nr_predictions > most_predictions: + most_predictions = nr_predictions + + open_set_error = np.zeros(most_predictions, dtype=np.int) for class_id in range(1, nr_classes + 1): predictions_class = predictions[class_id] @@ -189,7 +196,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in # If the image doesn't contain any objects of this class, # the prediction becomes a false positive. false_pos[i] = 1 - open_set_error += 1 + open_set_error[i] += 1 continue # Compute the IoU of this prediction with all ground truth boxes of the same class. @@ -240,7 +247,12 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in cumulative_true_positives.append(cumulative_true_pos) cumulative_false_positives.append(cumulative_false_pos) - return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error + cumulative_open_set_error = np.cumsum(open_set_error) + + return ( + true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, + open_set_error, cumulative_open_set_error + ) def get_precision_recall(number_gt_per_class: np.ndarray,