From 91b7febabb08e5317ea7b3be92c17fcd2764d151 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Wed, 10 Jul 2019 12:15:17 +0200 Subject: [PATCH] Changed train function to be an integration Signed-off-by: Jim Martens --- src/twomartens/masterthesis/cli.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 84c0876..b124586 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -58,10 +58,7 @@ def prepare(args: argparse.Namespace) -> None: def train(args: argparse.Namespace) -> None: - if args.network == "ssd" or args.network == "bayesian_ssd": - _ssd_train(args) - elif args.network == "auto_encoder": - _auto_encoder_train(args) + _train_execute_action(args, _ssd_train, _auto_encoder_train) def test(args: argparse.Namespace) -> None: @@ -164,6 +161,13 @@ def _config_execute_action(args: argparse.Namespace, on_get: callable, on_list() +def _train_execute_action(args: argparse.Namespace, on_ssd: callable, on_auto_encoder: callable) -> None: + if args.network == "ssd" or args.network == "bayesian_ssd": + on_ssd(args) + elif args.network == "auto_encoder": + on_auto_encoder(args) + + def _ssd_train(args: argparse.Namespace) -> None: import os import pickle