From 2d19ccf4a5c19fc3589c814b3f9b8ab3ab6b4532 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 11 Jul 2019 12:30:08 +0200 Subject: [PATCH] Changed ssd_train to use new compile_model function Signed-off-by: Jim Martens --- src/twomartens/masterthesis/cli.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index dc9b98e..b3d432f 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -204,6 +204,8 @@ def _ssd_train(args: argparse.Namespace) -> None: dropout_rate, top_k, pre_trained_weights_file) + loss_func = ssd.get_loss_func() + ssd.compile_model(ssd_model, learning_rate, loss_func) train_generator, train_length, train_debug_generator, \ val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args, @@ -235,7 +237,6 @@ def _ssd_train(args: argparse.Namespace) -> None: steps_per_val_epoch, ssd_model, weights_path, - learning_rate, tensorboard_callback ) @@ -560,7 +561,7 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable, train_generator: Generator, nr_batches_train: int, val_generator: Generator, nr_batches_val: int, model: tf.keras.models.Model, - weights_path: str, learning_rate: float, + weights_path: str, tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History: history = train_function( @@ -573,7 +574,6 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable, args.iteration, initial_epoch=0, nr_epochs=args.num_epochs, - lr=learning_rate, tensorboard_callback=tensorboard_callback )