diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 4f0d582..1f84c80 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -46,6 +46,7 @@ def main() -> None: _build_prepare(prepare_parser) _build_train(train_parser) _build_val(val_parser) + _build_test(test_parser) args = parser.parse_args() @@ -119,6 +120,24 @@ def _build_auto_encoder_val(parser: argparse.ArgumentParser) -> None: parser.add_argument("iteration_trained", type=int, help="the training iteration") +def _build_test(parser: argparse.ArgumentParser) -> None: + sub_parsers = parser.add_subparsers(dest="network") + sub_parsers.required = True + + ssd_bayesian_parser = sub_parsers.add_parser("bayesian_ssd", help="SSD with dropout layers") + ssd_parser = sub_parsers.add_parser("ssd", help="SSD") + + # build sub parsers + _build_ssd_test(ssd_bayesian_parser) + _build_ssd_test(ssd_parser) + + +def _build_ssd_test(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--output_path", type=str, help="path to the output directory") + parser.add_argument("--evaluation_path", type=str, help="path to the directory for the evaluation results") + parser.add_argument("iteration", type=int, help="the validation iteration to use") + + def _train(args: argparse.Namespace) -> None: if args.network == "auto_encoder": _auto_encoder_train(args)