diff --git a/src/twomartens/masterthesis/evaluate.py b/src/twomartens/masterthesis/evaluate.py
index 26934b5..5016a9c 100644
--- a/src/twomartens/masterthesis/evaluate.py
+++ b/src/twomartens/masterthesis/evaluate.py
@@ -20,6 +20,7 @@ Functionality to evaluate results of networks.
 
 Functions:
     get_number_gt_per_class(...): calculates the number of ground truth boxes per class
+    get_f1_score(...): computes the F1 score for every class
     match_predictions(...): matches predictions to ground truth boxes
 """
 from typing import Sequence, Union, Tuple, List
@@ -199,9 +200,6 @@ def match_predictions(predictions: Sequence[Sequence[Tuple[int, float, int, int,
                 # false positives.
                 false_pos[i] = 1
             else:
-                # If this is not a ground truth that is supposed to be evaluation-neutral
-                # (i.e. should be skipped for the evaluation) or if we don't even have the
-                # concept of neutral boxes.
                 if image_id not in gt_matched:
                     # True positive:
                     # If the matched ground truth box for this prediction hasn't been matched to a
@@ -268,6 +266,40 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
     return cumulative_precisions, cumulative_recalls
 
 
+def get_f1_score(cumulative_precisions: List[np.ndarray],
+                 cumulative_recalls: List[np.ndarray],
+                 nr_classes: int) -> List[np.ndarray]:
+    """
+    Computes the F1 score for every class.
+
+    Args:
+        cumulative_precisions: cumulative precisions for each class
+        cumulative_recalls: cumulative recalls for each class
+        nr_classes: number of classes
+
+    Returns:
+        cumulative F1 score per class (index 0 is a dummy entry so the
+        result can be indexed by class id like the inputs)
+    """
+    # keep a dummy entry at index 0 so the returned list is indexed by
+    # class id, mirroring cumulative_precisions/cumulative_recalls
+    cumulative_f1_scores = [[]]
+
+    # iterate over all classes
+    for class_id in range(1, nr_classes + 1):
+        cumulative_precision = cumulative_precisions[class_id]
+        cumulative_recall = cumulative_recalls[class_id]
+        # guard against 0/0 where precision + recall is 0: suppress the
+        # warning and define the F1 score as 0 instead of NaN there
+        with np.errstate(divide='ignore', invalid='ignore'):
+            f1_score = np.nan_to_num(
+                2 * cumulative_precision * cumulative_recall
+                / (cumulative_precision + cumulative_recall))
+        cumulative_f1_scores.append(f1_score)
+
+    return cumulative_f1_scores
+
+
 def get_mean_average_precisions(cumulative_precisions: List[np.ndarray],
                                 cumulative_recalls: List[np.ndarray],
                                 nr_classes: int) -> List[float]:
diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py
index 1f84c80..5d1dd68 100644
--- a/src/twomartens/masterthesis/main.py
+++ b/src/twomartens/masterthesis/main.py
@@ -247,6 +247,7 @@ def _ssd_test(args: argparse.Namespace) -> None:
                                                                 cum_true_positives,
                                                                 cum_false_positives,
                                                                 ssd.N_CLASSES)
+    f1_scores = evaluate.get_f1_score(cum_precisions, cum_recalls, ssd.N_CLASSES)
     average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls,
                                                               ssd.N_CLASSES)
     mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
@@ -257,7 +258,8 @@ def _ssd_test(args: argparse.Namespace) -> None:
         "cumulative_false_positives": cum_false_positives,
         "cumulative_precisions": cum_precisions,
         "cumulative_recalls": cum_recalls,
-        "average precisions": average_precisions,
+        "f1_scores": f1_scores,
+        "mean_average_precisions": average_precisions,
         "mean_average_precision": mean_average_precision
     }
 