diff --git a/src/twomartens/masterthesis/evaluate.py b/src/twomartens/masterthesis/evaluate.py index 8eeb65a..a73d91c 100644 --- a/src/twomartens/masterthesis/evaluate.py +++ b/src/twomartens/masterthesis/evaluate.py @@ -305,7 +305,7 @@ def get_precision_recall(number_gt_per_class: np.ndarray, # Iterate over all classes. for class_id in range(1, nr_classes + 1): - if number_gt_per_class[class_id] == 0 or cumulative_true_positives[class_id].shape[0] == 0: + if number_gt_per_class[class_id] == 0: cumulative_precisions.append([]) cumulative_recalls.append([]) continue @@ -323,8 +323,10 @@ def get_precision_recall(number_gt_per_class: np.ndarray, diff_to_largest_class = cumulative_precision_micro.shape[0] - cumulative_precision.shape[0] if diff_to_largest_class: - repeated_last_precision = np.tile(cumulative_precision[-1], diff_to_largest_class) - repeated_last_recall = np.tile(cumulative_recall[-1], diff_to_largest_class) + highest_precision = cumulative_precision[-1] if cumulative_precision.shape[0] else 0 + highest_recall = cumulative_recall[-1] if cumulative_recall.shape[0] else 0 + repeated_last_precision = np.tile(highest_precision, diff_to_largest_class) + repeated_last_recall = np.tile(highest_recall, diff_to_largest_class) extended_precision = np.concatenate((cumulative_precision, repeated_last_precision)) extended_recall = np.concatenate((cumulative_recall, repeated_last_recall)) cumulative_precision_macro += extended_precision