Changed ssd_train to use new compile_model function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -204,6 +204,8 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
dropout_rate,
|
dropout_rate,
|
||||||
top_k,
|
top_k,
|
||||||
pre_trained_weights_file)
|
pre_trained_weights_file)
|
||||||
|
loss_func = ssd.get_loss_func()
|
||||||
|
ssd.compile_model(ssd_model, learning_rate, loss_func)
|
||||||
|
|
||||||
train_generator, train_length, train_debug_generator, \
|
train_generator, train_length, train_debug_generator, \
|
||||||
val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args,
|
val_generator, val_length, val_debug_generator = _ssd_train_get_generators(args,
|
||||||
@ -235,7 +237,6 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
steps_per_val_epoch,
|
steps_per_val_epoch,
|
||||||
ssd_model,
|
ssd_model,
|
||||||
weights_path,
|
weights_path,
|
||||||
learning_rate,
|
|
||||||
tensorboard_callback
|
tensorboard_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -560,7 +561,7 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
|
|||||||
train_generator: Generator, nr_batches_train: int,
|
train_generator: Generator, nr_batches_train: int,
|
||||||
val_generator: Generator, nr_batches_val: int,
|
val_generator: Generator, nr_batches_val: int,
|
||||||
model: tf.keras.models.Model,
|
model: tf.keras.models.Model,
|
||||||
weights_path: str, learning_rate: float,
|
weights_path: str,
|
||||||
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
||||||
|
|
||||||
history = train_function(
|
history = train_function(
|
||||||
@ -573,7 +574,6 @@ def _ssd_train_call(args: argparse.Namespace, train_function: callable,
|
|||||||
args.iteration,
|
args.iteration,
|
||||||
initial_epoch=0,
|
initial_epoch=0,
|
||||||
nr_epochs=args.num_epochs,
|
nr_epochs=args.num_epochs,
|
||||||
lr=learning_rate,
|
|
||||||
tensorboard_callback=tensorboard_callback
|
tensorboard_callback=tensorboard_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user