diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 256589d..922efd7 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -22,6 +22,7 @@ Functions: """ import argparse + def main() -> None: """ Provides command line interface. @@ -163,8 +164,10 @@ def _ssd_test(args: argparse.Namespace) -> None: import glob import pickle - import numpy as np import tensorflow as tf + + from twomartens.masterthesis import evaluate + from twomartens.masterthesis import ssd tf.enable_eager_execution() @@ -186,9 +189,40 @@ def _ssd_test(args: argparse.Namespace) -> None: # store labels for later use with open(label_file, "wb") as file: pickle.dump(labels, file) - - # TODO implement evaluate.py analogous to average_precision_evaluator + number_gt_per_class = evaluate.get_number_gt_per_class(labels, ssd.N_CLASSES) + + # retrieve predictions and un-batch them + files = glob.glob(f"{output_path}/*ssd_predictions*") + predictions = [] + for filename in files: + with open(filename, "rb") as file: + # get predictions per batch + _predictions = pickle.load(file) + # select only forward pass + _predictions = _predictions[0] + predictions.extend(_predictions) + del _predictions + + # prepare predictions for further use + predictions_per_class = evaluate.prepare_predictions(predictions, ssd.N_CLASSES) + del predictions + + # compute matches between predictions and ground truth + true_positives, false_positives, \ + cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class, + labels, + ssd.N_CLASSES) + del labels + cum_precisions, cum_recalls = evaluate.get_precision_recall(number_gt_per_class, + cum_true_positives, + cum_false_positives, + ssd.N_CLASSES) + average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES) + mean_average_precision = evaluate.get_mean_average_precision(average_precisions) + + # TODO store result of evaluation + def _val(args: argparse.Namespace) -> None: if args.network == "ssd" or args.network == "bayesian_ssd":