Implemented f1 score calculation

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-05-16 15:37:00 +02:00
parent f04f71bc50
commit 160346b4cb
2 changed files with 30 additions and 4 deletions

View File

@ -20,6 +20,7 @@ Functionality to evaluate results of networks.
Functions: Functions:
get_number_gt_per_class(...): calculates the number of ground truth boxes per class get_number_gt_per_class(...): calculates the number of ground truth boxes per class
get_f1_score(...): computes the F1 score for every class
match_predictions(...): matches predictions to ground truth boxes match_predictions(...): matches predictions to ground truth boxes
""" """
from typing import Sequence, Union, Tuple, List from typing import Sequence, Union, Tuple, List
@ -199,9 +200,6 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
# false positives. # false positives.
false_pos[i] = 1 false_pos[i] = 1
else: else:
# If this is not a ground truth that is supposed to be evaluation-neutral
# (i.e. should be skipped for the evaluation) or if we don't even have the
# concept of neutral boxes.
if image_id not in gt_matched: if image_id not in gt_matched:
# True positive: # True positive:
# If the matched ground truth box for this prediction hasn't been matched to a # If the matched ground truth box for this prediction hasn't been matched to a
@ -268,6 +266,32 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
return cumulative_precisions, cumulative_recalls return cumulative_precisions, cumulative_recalls
def get_f1_score(cumulative_precisions: List[np.ndarray],
cumulative_recalls: List[np.ndarray],
nr_classes: int) -> List[np.ndarray]:
"""
Computes the F1 score for every class.
Args:
cumulative_precisions: cumulative precisions for each class
cumulative_recalls: cumulative recalls for each class
nr_classes: number of classes
Returns:
cumulative F1 score per class
"""
cumulative_f1_scores = [[]]
# iterate over all classes
for class_id in range(1, nr_classes + 1):
cumulative_precision = cumulative_precisions[class_id]
cumulative_recall = cumulative_recalls[class_id]
f1_score = 2 * ((cumulative_precision * cumulative_recall) / (cumulative_precision + cumulative_recall))
cumulative_f1_scores.append(f1_score)
return cumulative_f1_scores
def get_mean_average_precisions(cumulative_precisions: List[np.ndarray], def get_mean_average_precisions(cumulative_precisions: List[np.ndarray],
cumulative_recalls: List[np.ndarray], cumulative_recalls: List[np.ndarray],
nr_classes: int) -> List[float]: nr_classes: int) -> List[float]:

View File

@ -247,6 +247,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
cum_true_positives, cum_true_positives,
cum_false_positives, cum_false_positives,
ssd.N_CLASSES) ssd.N_CLASSES)
f1_scores = evaluate.get_f1_score(cum_precisions, cum_recalls, ssd.N_CLASSES)
average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES) average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
mean_average_precision = evaluate.get_mean_average_precision(average_precisions) mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
@ -257,7 +258,8 @@ def _ssd_test(args: argparse.Namespace) -> None:
"cumulative_false_positives": cum_false_positives, "cumulative_false_positives": cum_false_positives,
"cumulative_precisions": cum_precisions, "cumulative_precisions": cum_precisions,
"cumulative_recalls": cum_recalls, "cumulative_recalls": cum_recalls,
"average precisions": average_precisions, "f1_scores": f1_scores,
"mean_average_precisions": average_precisions,
"mean_average_precision": mean_average_precision "mean_average_precision": mean_average_precision
} }