Implemented cumulative open set error
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -353,7 +353,7 @@ def _ssd_evaluate(args: argparse.Namespace) -> None:
|
|||||||
|
|
||||||
true_positives, false_positives, \
|
true_positives, false_positives, \
|
||||||
cum_true_positives, cum_false_positives, \
|
cum_true_positives, cum_false_positives, \
|
||||||
open_set_error = evaluate.match_predictions(predictions_per_class, labels,
|
open_set_error, cumulative_open_set_error = evaluate.match_predictions(predictions_per_class, labels,
|
||||||
bounding_box_utils.iou,
|
bounding_box_utils.iou,
|
||||||
nr_classes, iou_threshold)
|
nr_classes, iou_threshold)
|
||||||
|
|
||||||
@ -375,7 +375,8 @@ def _ssd_evaluate(args: argparse.Namespace) -> None:
|
|||||||
f1_scores,
|
f1_scores,
|
||||||
average_precisions,
|
average_precisions,
|
||||||
mean_average_precision,
|
mean_average_precision,
|
||||||
open_set_error)
|
open_set_error,
|
||||||
|
cumulative_open_set_error)
|
||||||
|
|
||||||
_pickle(result_file, results)
|
_pickle(result_file, results)
|
||||||
|
|
||||||
@ -829,7 +830,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
|||||||
f1_scores: Sequence[np.ndarray],
|
f1_scores: Sequence[np.ndarray],
|
||||||
average_precisions: Sequence[float],
|
average_precisions: Sequence[float],
|
||||||
mean_average_precision: float,
|
mean_average_precision: float,
|
||||||
open_set_error: int
|
open_set_error: np.ndarray,
|
||||||
|
cumulative_open_set_error: np.ndarray
|
||||||
) -> Dict[str, Union[np.ndarray, float, int]]:
|
) -> Dict[str, Union[np.ndarray, float, int]]:
|
||||||
results = {
|
results = {
|
||||||
"true_positives": true_positives,
|
"true_positives": true_positives,
|
||||||
@ -841,7 +843,8 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
|||||||
"f1_scores": f1_scores,
|
"f1_scores": f1_scores,
|
||||||
"mean_average_precisions": average_precisions,
|
"mean_average_precisions": average_precisions,
|
||||||
"mean_average_precision": mean_average_precision,
|
"mean_average_precision": mean_average_precision,
|
||||||
"open_set_error": open_set_error
|
"open_set_error": open_set_error,
|
||||||
|
"cumulative_open_set_error": cumulative_open_set_error
|
||||||
}
|
}
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@ -99,7 +99,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||||||
border_pixels: str = "include",
|
border_pixels: str = "include",
|
||||||
sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
|
sorting_algorithm: str = "quicksort") -> Tuple[List[np.ndarray], List[np.ndarray],
|
||||||
List[np.ndarray], List[np.ndarray],
|
List[np.ndarray], List[np.ndarray],
|
||||||
int]:
|
np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Matches predictions to ground truth boxes.
|
Matches predictions to ground truth boxes.
|
||||||
|
|
||||||
@ -124,13 +124,20 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
true positives, false positives, cumulative true positives, and cumulative false positives for
|
true positives, false positives, cumulative true positives, and cumulative false positives for
|
||||||
each class, open set error as defined by Miller et al
|
each class, open set error as defined by Miller et al, cumulative open set error
|
||||||
"""
|
"""
|
||||||
true_positives = [[]] # The false positives for each class, sorted by descending confidence.
|
true_positives = [[]] # The false positives for each class, sorted by descending confidence.
|
||||||
false_positives = [[]] # The true positives for each class, sorted by descending confidence.
|
false_positives = [[]] # The true positives for each class, sorted by descending confidence.
|
||||||
open_set_error = 0
|
|
||||||
cumulative_true_positives = [[]]
|
cumulative_true_positives = [[]]
|
||||||
cumulative_false_positives = [[]]
|
cumulative_false_positives = [[]]
|
||||||
|
most_predictions = -1
|
||||||
|
|
||||||
|
for class_id in range(1, nr_classes + 1):
|
||||||
|
nr_predictions = len(predictions[class_id])
|
||||||
|
if nr_predictions > most_predictions:
|
||||||
|
most_predictions = nr_predictions
|
||||||
|
|
||||||
|
open_set_error = np.zeros(most_predictions, dtype=np.int)
|
||||||
|
|
||||||
for class_id in range(1, nr_classes + 1):
|
for class_id in range(1, nr_classes + 1):
|
||||||
predictions_class = predictions[class_id]
|
predictions_class = predictions[class_id]
|
||||||
@ -189,7 +196,7 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||||||
# If the image doesn't contain any objects of this class,
|
# If the image doesn't contain any objects of this class,
|
||||||
# the prediction becomes a false positive.
|
# the prediction becomes a false positive.
|
||||||
false_pos[i] = 1
|
false_pos[i] = 1
|
||||||
open_set_error += 1
|
open_set_error[i] += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Compute the IoU of this prediction with all ground truth boxes of the same class.
|
# Compute the IoU of this prediction with all ground truth boxes of the same class.
|
||||||
@ -240,7 +247,12 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, float, in
|
|||||||
cumulative_true_positives.append(cumulative_true_pos)
|
cumulative_true_positives.append(cumulative_true_pos)
|
||||||
cumulative_false_positives.append(cumulative_false_pos)
|
cumulative_false_positives.append(cumulative_false_pos)
|
||||||
|
|
||||||
return true_positives, false_positives, cumulative_true_positives, cumulative_false_positives, open_set_error
|
cumulative_open_set_error = np.cumsum(open_set_error)
|
||||||
|
|
||||||
|
return (
|
||||||
|
true_positives, false_positives, cumulative_true_positives, cumulative_false_positives,
|
||||||
|
open_set_error, cumulative_open_set_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_precision_recall(number_gt_per_class: np.ndarray,
|
def get_precision_recall(number_gt_per_class: np.ndarray,
|
||||||
|
|||||||
Reference in New Issue
Block a user