Changed ssd_train to use new compile_model function

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-11 12:30:08 +02:00
parent 06548dc8e5
commit 2d19ccf4a5

View File

@ -204,6 +204,8 @@ def _ssd_train(args: argparse.Namespace) -> None:
dropout_rate, dropout_rate,
top_k, top_k,
pre_trained_weights_file) pre_trained_weights_file)
loss_func = ssd.get_loss_func()
ssd.compile_model(ssd_model, learning_rate, loss_func)
train_generator, train_length, train_debug_generator, \ train_generator, train_length, train_debug_generator, \
val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args, val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args,
@ -235,7 +237,6 @@ def _ssd_train(args: argparse.Namespace) -> None:
steps_per_val_epoch, steps_per_val_epoch,
ssd_model, ssd_model,
weights_path, weights_path,
learning_rate,
tensorboard_callback tensorboard_callback
) )
@ -560,7 +561,7 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
train_generator: Generator, nr_batches_train: int, train_generator: Generator, nr_batches_train: int,
val_generator: Generator, nr_batches_val: int, val_generator: Generator, nr_batches_val: int,
model: tf.keras.models.Model, model: tf.keras.models.Model,
weights_path: str, learning_rate: float, weights_path: str,
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History: tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
history = train_function( history = train_function(
@ -573,7 +574,6 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
args.iteration, args.iteration,
initial_epoch=0, initial_epoch=0,
nr_epochs=args.num_epochs, nr_epochs=args.num_epochs,
lr=learning_rate,
tensorboard_callback=tensorboard_callback tensorboard_callback=tensorboard_callback
) )