Implemented cumulative open set error
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
28a0d35d36
commit
1db9fa1c0a
|
@ -353,9 +353,9 @@ def _ssd_evaluate(args: argparse.Namespace) -> None:
|
|||
|
||||
true_positives, false_positives, \
|
||||
cum_true_positives, cum_false_positives, \
|
||||
open_set_error = evaluate.match_predictions(predictions_per_class, labels,
|
||||
bounding_box_utils.iou,
|
||||
nr_classes, iou_threshold)
|
||||
open_set_error, cumulative_open_set_error = evaluate.match_predictions(predictions_per_class, labels,
|
||||
bounding_box_utils.iou,
|
||||
nr_classes, iou_threshold)
|
||||
|
||||
cum_precisions, cum_recalls = evaluate.get_precision_recall(number_gt_per_class,
|
||||
cum_true_positives,
|
||||
|
@ -375,7 +375,8 @@ def _ssd_evaluate(args: argparse.Namespace) -> None:
|
|||
f1_scores,
|
||||
average_precisions,
|
||||
mean_average_precision,
|
||||
open_set_error)
|
||||
open_set_error,
|
||||
cumulative_open_set_error)
|
||||
|
||||
_pickle(result_file, results)
|
||||
|
||||
|
@ -829,7 +830,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
|||
f1_scores: Sequence[np.ndarray],
|
||||
average_precisions: Sequence[float],
|
||||
mean_average_precision: float,
|
||||
open_set_error: int
|
||||
open_set_error: np.ndarray,
|
||||
cumulative_open_set_error: np.ndarray
|
||||
) -> Dict[str, Union[np.ndarray, float, int]]:
|
||||
results = {
|
||||
"true_positives": true_positives,
|
||||
|
@ -841,7 +843,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
|||
"f1_scores": f1_scores,
|
||||
"mean_average_precisions": average_precisions,
|
||||
"mean_average_precision": mean_average_precision,
|
||||
"open_set_error": open_set_error
|
||||
"open_set_error": open_set_error,
|
||||
"cumulative_open_set_error": cumulative_open_set_error
|
||||
}
|
||||
|
||||
return results
|
||||
|
|
|
@ -99,7 +99,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||
border_pixels: str = "include",
|
||||
sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
|
||||
List[np.ndarray], List[np.ndarray],
|
||||
int]:
|
||||
np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Matches predictions to ground truth boxes.
|
||||
|
||||
|
@ -124,13 +124,20 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||
|
||||
Returns:
|
||||
true positives, false positives, cumulative true positives, and cumulative false positives for
|
||||
each class, open set error as defined by Miller et al
|
||||
each class, open set error as defined by Miller et al, cumulative open set error
|
||||
"""
|
||||
true_positives = [[]] # The false positives for each class, sorted by descending confidence.
|
||||
false_positives = [[]] # The true positives for each class, sorted by descending confidence.
|
||||
open_set_error = 0
|
||||
cumulative_true_positives = [[]]
|
||||
cumulative_false_positives = [[]]
|
||||
most_predictions = -1
|
||||
|
||||
for class_id in range(1, nr_classes + 1):
|
||||
nr_predictions = len(predictions[class_id])
|
||||
if nr_predictions > most_predictions:
|
||||
most_predictions = nr_predictions
|
||||
|
||||
open_set_error = np.zeros(most_predictions, dtype=np.int)
|
||||
|
||||
for class_id in range(1, nr_classes + 1):
|
||||
predictions_class = predictions[class_id]
|
||||
|
@ -189,7 +196,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||
# If the image doesn't contain any objects of this class,
|
||||
# the prediction becomes a false positive.
|
||||
false_pos[i] = 1
|
||||
open_set_error += 1
|
||||
open_set_error[i] += 1
|
||||
continue
|
||||
|
||||
# Compute the IoU of this prediction with all ground truth boxes of the same class.
|
||||
|
@ -240,7 +247,12 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||
cumulative_true_positives.append(cumulative_true_pos)
|
||||
cumulative_false_positives.append(cumulative_false_pos)
|
||||
|
||||
return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error
|
||||
cumulative_open_set_error = np.cumsum(open_set_error)
|
||||
|
||||
return (
|
||||
true_positives, false_positives, cumulative_true_positives, cumulative_false_positives,
|
||||
open_set_error, cumulative_open_set_error
|
||||
)
|
||||
|
||||
|
||||
def get_precision_recall(number_gt_per_class: np.ndarray,
|
||||
|
|
Loading…
Reference in New Issue