diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 0f90e5f..ab07473 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -153,7 +153,41 @@ def _auto_encoder_train(args: argparse.Namespace) -> None: def _test(args: argparse.Namespace) -> None: - raise NotImplementedError + if args.network == "ssd": + _ssd_test(args) + else: + raise NotImplementedError + + +def _ssd_test(args: argparse.Namespace) -> None: + import glob + import pickle + + import numpy as np + import tensorflow as tf + + tf.enable_eager_execution() + + batch_size = 16 + use_dropout = False if args.network == "ssd" else True + output_path = f"{args.output_path}/val/{args.network}/{args.iteration}" + label_file = f"{output_path}/labels.bin" + + # retrieve labels and un-batch them + files = glob.glob(f"{output_path}/*ssd_labels*") + labels = [] + for filename in files: + with open(filename, "rb") as file: + # get labels per batch + _labels = np.load(file, allow_pickle=False, fix_imports=False) + # exclude padded label entries + real_labels = _labels[:, :, 0] != -1 + labels.extend(_labels[real_labels]) + # store labels for later use + with open(label_file, "wb") as file: + pickle.dump(labels, file) + + # TODO implement evaluate.py analogous to average_precision_evaluator def _val(args: argparse.Namespace) -> None: