From 9d726ebb10e031ac75557dac8cd8c78baa193603 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 16 May 2019 14:07:20 +0200 Subject: [PATCH] Finished vanilla SSD evaluation code Signed-off-by: Jim Martens --- src/twomartens/masterthesis/main.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 922efd7..4f0d582 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -174,7 +174,11 @@ def _ssd_test(args: argparse.Namespace) -> None: batch_size = 16 use_dropout = False if args.network == "ssd" else True output_path = f"{args.output_path}/val/{args.network}/{args.iteration}" + evaluation_path = f"{args.evaluation_path}/{args.network}" + result_file = f"{evaluation_path}/results-{args.iteration}.bin" label_file = f"{output_path}/labels.bin" + predictions_file = f"{output_path}/predictions.bin" + predictions_per_class_file = f"{output_path}/predictions_class.bin" # retrieve labels and un-batch them files = glob.glob(f"{output_path}/*ssd_labels*") @@ -205,9 +209,15 @@ def _ssd_test(args: argparse.Namespace) -> None: del _predictions # prepare predictions for further use + with open(predictions_file, "wb") as file: + pickle.dump(predictions, file) + predictions_per_class = evaluate.prepare_predictions(predictions, ssd.N_CLASSES) del predictions + with open(predictions_per_class_file, "wb") as file: + pickle.dump(predictions_per_class, file) + # compute matches between predictions and ground truth true_positives, false_positives, \ cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class, @@ -221,7 +231,19 @@ def _ssd_test(args: argparse.Namespace) -> None: average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES) mean_average_precision = evaluate.get_mean_average_precision(average_precisions) - # TODO store result of evaluation + results = { + "true_positives": true_positives, + "false_positives": false_positives, + "cumulative_true_positives": cum_true_positives, + "cumulative_false_positives": cum_false_positives, + "cumulative_precisions": cum_precisions, + "cumulative_recalls": cum_recalls, + "average precisions": average_precisions, + "mean_average_precision": mean_average_precision + } + + with open(result_file, "wb") as file: + pickle.dump(results, file) def _val(args: argparse.Namespace) -> None: