diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 84c0876..b124586 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -58,10 +58,7 @@ def prepare(args: argparse.Namespace) -> None: def train(args: argparse.Namespace) -> None: - if args.network == "ssd" or args.network == "bayesian_ssd": - _ssd_train(args) - elif args.network == "auto_encoder": - _auto_encoder_train(args) + _train_execute_action(args, _ssd_train, _auto_encoder_train) def test(args: argparse.Namespace) -> None: @@ -164,6 +161,13 @@ def _config_execute_action(args: argparse.Namespace, on_get: callable, on_list() +def _train_execute_action(args: argparse.Namespace, on_ssd: callable, on_auto_encoder: callable) -> None: + if args.network == "ssd" or args.network == "bayesian_ssd": + on_ssd(args) + elif args.network == "auto_encoder": + on_auto_encoder(args) + + def _ssd_train(args: argparse.Namespace) -> None: import os import pickle