diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index ff01f95..331af4a 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -475,6 +475,25 @@ def _visualise_metrics(visualise_precision_recall: callable, output_path, f"macro-{threshold}" if use_entropy_threshold else "macro") + precision = metrics["cumulative_precisions"] + recall = metrics["cumulative_recalls"] + f1_scores = metrics["f1_scores"] + + class_scores = {} + for i in range(1, conf_obj.parameters.nr_classes + 1): + if not len(f1_scores[i]): + continue + + max_f1_score_index = np.argmax(f1_scores[i]) + max_f1_score = f1_scores[i][max_f1_score_index] + precision_at_max_f1 = precision[i][max_f1_score_index] + recall_at_max_f1 = recall[i][max_f1_score_index] + class_scores[i] = { + "max_f1_score": max_f1_score, + "precision_at_max_f1": precision_at_max_f1, + "recall_at_max_f1": recall_at_max_f1 + } + max_f1_score_micro_index = np.argmax(f1_scores_micro, axis=0) max_f1_score_micro = f1_scores_micro[max_f1_score_micro_index] precision_at_max_f1_micro = precision_micro[max_f1_score_micro_index] @@ -500,6 +519,7 @@ def _visualise_metrics(visualise_precision_recall: callable, "precision_at_max_f1_macro": precision_at_max_f1_macro, "recall_at_max_f1_macro": recall_at_max_f1_macro, "ose_at_max_f1_macro": int(ose_at_max_f1_macro), + "class_scores": class_scores }, file, indent=2)