Changed train function to be an integration
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
parent
a90c6c7c92
commit
91b7febabb
|
@ -58,10 +58,7 @@ def prepare(args: argparse.Namespace) -> None:
|
|||
|
||||
|
||||
def train(args: argparse.Namespace) -> None:
|
||||
if args.network == "ssd" or args.network == "bayesian_ssd":
|
||||
_ssd_train(args)
|
||||
elif args.network == "auto_encoder":
|
||||
_auto_encoder_train(args)
|
||||
_train_execute_action(args, _ssd_train, _auto_encoder_train)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace) -> None:
|
||||
|
@ -164,6 +161,13 @@ def _config_execute_action(args: argparse.Namespace, on_get: callable,
|
|||
on_list()
|
||||
|
||||
|
||||
def _train_execute_action(args: argparse.Namespace, on_ssd: callable, on_auto_encoder: callable) -> None:
|
||||
if args.network == "ssd" or args.network == "bayesian_ssd":
|
||||
on_ssd(args)
|
||||
elif args.network == "auto_encoder":
|
||||
on_auto_encoder(args)
|
||||
|
||||
|
||||
def _ssd_train(args: argparse.Namespace) -> None:
|
||||
import os
|
||||
import pickle
|
||||
|
|
Loading…
Reference in New Issue