Implemented f1 score calculation
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -20,6 +20,7 @@ Functionality to evaluate results of networks.
|
|||||||
|
|
||||||
Functions:
|
Functions:
|
||||||
get_number_gt_per_class(...): calculates the number of ground truth boxes per class
|
get_number_gt_per_class(...): calculates the number of ground truth boxes per class
|
||||||
|
get_f1_score(...): computes the F1 score for every class
|
||||||
match_predictions(...): matches predictions to ground truth boxes
|
match_predictions(...): matches predictions to ground truth boxes
|
||||||
"""
|
"""
|
||||||
from typing import Sequence, Union, Tuple, List
|
from typing import Sequence, Union, Tuple, List
|
||||||
@ -199,9 +200,6 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
|
|||||||
# false positives.
|
# false positives.
|
||||||
false_pos[i] = 1
|
false_pos[i] = 1
|
||||||
else:
|
else:
|
||||||
# If this is not a ground truth that is supposed to be evaluation-neutral
|
|
||||||
# (i.e. should be skipped for the evaluation) or if we don't even have the
|
|
||||||
# concept of neutral boxes.
|
|
||||||
if image_id not in gt_matched:
|
if image_id not in gt_matched:
|
||||||
# True positive:
|
# True positive:
|
||||||
# If the matched ground truth box for this prediction hasn't been matched to a
|
# If the matched ground truth box for this prediction hasn't been matched to a
|
||||||
@ -268,6 +266,32 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
|
|||||||
return cumulative_precisions, cumulative_recalls
|
return cumulative_precisions, cumulative_recalls
|
||||||
|
|
||||||
|
|
||||||
|
def get_f1_score(cumulative_precisions: List[np.ndarray],
                 cumulative_recalls: List[np.ndarray],
                 nr_classes: int) -> List[np.ndarray]:
    """
    Computes the F1 score for every class.

    The F1 score is the harmonic mean of precision and recall:
    F1 = 2 * (precision * recall) / (precision + recall).
    Where precision + recall is zero (no true positives and no
    detections at that cutoff) the F1 score is defined as 0 instead
    of producing a NaN via division by zero.

    Args:
        cumulative_precisions: cumulative precisions for each class
            (index 0 is assumed to be a background placeholder — the
            loop below starts at class id 1)
        cumulative_recalls: cumulative recalls for each class
        nr_classes: number of classes (excluding the background entry)

    Returns:
        cumulative F1 score per class; index 0 holds an empty
        placeholder so that class ids index the list directly
    """
    # index 0 is a placeholder so class_id can be used as list index
    cumulative_f1_scores = [[]]

    # iterate over all classes (class 0 / background is skipped)
    for class_id in range(1, nr_classes + 1):
        cumulative_precision = np.asarray(cumulative_precisions[class_id],
                                          dtype=float)
        cumulative_recall = np.asarray(cumulative_recalls[class_id],
                                       dtype=float)
        denominator = cumulative_precision + cumulative_recall
        # safe division: entries with a zero denominator stay 0
        # instead of becoming NaN with a RuntimeWarning
        f1_score = np.divide(2.0 * cumulative_precision * cumulative_recall,
                             denominator,
                             out=np.zeros_like(denominator),
                             where=denominator > 0)
        cumulative_f1_scores.append(f1_score)

    return cumulative_f1_scores
|
||||||
|
|
||||||
|
|
||||||
def get_mean_average_precisions(cumulative_precisions: List[np.ndarray],
|
def get_mean_average_precisions(cumulative_precisions: List[np.ndarray],
|
||||||
cumulative_recalls: List[np.ndarray],
|
cumulative_recalls: List[np.ndarray],
|
||||||
nr_classes: int) -> List[float]:
|
nr_classes: int) -> List[float]:
|
||||||
|
|||||||
@ -247,6 +247,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
cum_true_positives,
|
cum_true_positives,
|
||||||
cum_false_positives,
|
cum_false_positives,
|
||||||
ssd.N_CLASSES)
|
ssd.N_CLASSES)
|
||||||
|
f1_scores = evaluate.get_f1_score(cum_precisions, cum_recalls, ssd.N_CLASSES)
|
||||||
average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
|
average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
|
||||||
mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
|
mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
|
||||||
|
|
||||||
@ -257,7 +258,8 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
"cumulative_false_positives": cum_false_positives,
|
"cumulative_false_positives": cum_false_positives,
|
||||||
"cumulative_precisions": cum_precisions,
|
"cumulative_precisions": cum_precisions,
|
||||||
"cumulative_recalls": cum_recalls,
|
"cumulative_recalls": cum_recalls,
|
||||||
"average precisions": average_precisions,
|
"f1_scores": f1_scores,
|
||||||
|
"mean_average_precisions": average_precisions,
|
||||||
"mean_average_precision": mean_average_precision
|
"mean_average_precision": mean_average_precision
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user