Changed ssd_train to use new compile_model function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -204,6 +204,8 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
||||
dropout_rate,
|
||||
top_k,
|
||||
pre_trained_weights_file)
|
||||
loss_func = ssd.get_loss_func()
|
||||
ssd.compile_model(ssd_model, learning_rate, loss_func)
|
||||
|
||||
train_generator, train_length, train_debug_generator, \
|
||||
val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args,
|
||||
@ -235,7 +237,6 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
||||
steps_per_val_epoch,
|
||||
ssd_model,
|
||||
weights_path,
|
||||
learning_rate,
|
||||
tensorboard_callback
|
||||
)
|
||||
|
||||
@ -560,7 +561,7 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
|
||||
train_generator: Generator, nr_batches_train: int,
|
||||
val_generator: Generator, nr_batches_val: int,
|
||||
model: tf.keras.models.Model,
|
||||
weights_path: str, learning_rate: float,
|
||||
weights_path: str,
|
||||
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
||||
|
||||
history = train_function(
|
||||
@ -573,7 +574,6 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
|
||||
args.iteration,
|
||||
initial_epoch=0,
|
||||
nr_epochs=args.num_epochs,
|
||||
lr=learning_rate,
|
||||
tensorboard_callback=tensorboard_callback
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user