Implemented f1 score calculation
Signed-off-by: Jim Martens <github@2martens.de>
@@ -20,6 +20,7 @@ Functionality to evaluate results of networks.

Functions:
    get_number_gt_per_class(...): calculates the number of ground truth boxes per class
    get_f1_score(...): computes the F1 score for every class
    match_predictions(...): matches predictions to ground truth boxes
"""

from typing import Sequence, Union, Tuple, List
@@ -199,9 +200,6 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,

            # false positives.
            false_pos[i] = 1
        else:
            # If this is not a ground truth that is supposed to be evaluation-neutral
            # (i.e. should be skipped for the evaluation) or if we don't even have the
            # concept of neutral boxes.
            if image_id not in gt_matched:
                # True positive:
                # If the matched ground truth box for this prediction hasn't been matched to a
@@ -268,6 +266,32 @@ def get_precision_recall(number_gt_per_class: np.ndarray,

    return cumulative_precisions, cumulative_recalls


def get_f1_score(cumulative_precisions: List[np.ndarray],
                 cumulative_recalls: List[np.ndarray],
                 nr_classes: int) -> List[np.ndarray]:
    """
    Computes the F1 score for every class.

    Args:
        cumulative_precisions: cumulative precisions for each class
        cumulative_recalls: cumulative recalls for each class
        nr_classes: number of classes

    Returns:
        cumulative F1 score per class
    """
    # index 0 is reserved for the background class and stays empty
    cumulative_f1_scores = [[]]

    # iterate over all classes
    for class_id in range(1, nr_classes + 1):
        cumulative_precision = cumulative_precisions[class_id]
        cumulative_recall = cumulative_recalls[class_id]
        # guard against 0/0 at thresholds where precision and recall are both zero
        denominator = cumulative_precision + cumulative_recall
        f1_score = np.divide(2 * cumulative_precision * cumulative_recall, denominator,
                             out=np.zeros_like(denominator), where=denominator > 0)
        cumulative_f1_scores.append(f1_score)

    return cumulative_f1_scores


def get_mean_average_precisions(cumulative_precisions: List[np.ndarray],
                                cumulative_recalls: List[np.ndarray],
                                nr_classes: int) -> List[float]:
@@ -247,6 +247,7 @@ def _ssd_test(args: argparse.Namespace) -> None:

                                                        cum_true_positives,
                                                        cum_false_positives,
                                                        ssd.N_CLASSES)
    f1_scores = evaluate.get_f1_score(cum_precisions, cum_recalls, ssd.N_CLASSES)
    average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
    mean_average_precision = evaluate.get_mean_average_precision(average_precisions)

@@ -257,7 +258,8 @@ def _ssd_test(args: argparse.Namespace) -> None:

        "cumulative_false_positives": cum_false_positives,
        "cumulative_precisions": cum_precisions,
        "cumulative_recalls": cum_recalls,
        "average_precisions": average_precisions,
        "f1_scores": f1_scores,
        "mean_average_precision": mean_average_precision
    }
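
A minimal sketch of how the new get_f1_score behaves on toy data, assuming the evaluation module above is importable as evaluate (the import path and the numbers are illustrative, not part of the commit):

import numpy as np

from twomartens.masterthesis import evaluate  # assumed import path

# toy cumulative precision/recall curves for a single foreground class;
# index 0 is the empty background placeholder, as in get_f1_score
cumulative_precisions = [[], np.array([1.0, 0.5, 0.66, 0.75])]
cumulative_recalls = [[], np.array([0.25, 0.25, 0.5, 0.75])]

f1_scores = evaluate.get_f1_score(cumulative_precisions, cumulative_recalls, nr_classes=1)
# element-wise harmonic mean of precision and recall, e.g.
# 2 * 0.75 * 0.75 / (0.75 + 0.75) = 0.75 at the last confidence threshold
print(f1_scores[1])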