Updated train_keras to work with SSD model directly
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -507,7 +507,7 @@ def train_keras(train_generator: callable,
|
|||||||
steps_per_epoch_train: int,
|
steps_per_epoch_train: int,
|
||||||
val_generator: callable,
|
val_generator: callable,
|
||||||
steps_per_epoch_val: int,
|
steps_per_epoch_val: int,
|
||||||
ssd_model: Union[SSD, DropoutSSD],
|
ssd_model: tf.keras.models.Model,
|
||||||
weights_prefix: str,
|
weights_prefix: str,
|
||||||
iteration: int,
|
iteration: int,
|
||||||
initial_epoch: int,
|
initial_epoch: int,
|
||||||
@ -522,7 +522,7 @@ def train_keras(train_generator: callable,
|
|||||||
steps_per_epoch_train: number of batches per training epoch
|
steps_per_epoch_train: number of batches per training epoch
|
||||||
val_generator: generator of validation data
|
val_generator: generator of validation data
|
||||||
steps_per_epoch_val: number of batches per validation epoch
|
steps_per_epoch_val: number of batches per validation epoch
|
||||||
ssd_model: wrapper of SSD model
|
ssd_model: SSD model
|
||||||
weights_prefix: prefix for weights directory
|
weights_prefix: prefix for weights directory
|
||||||
iteration: identifier for current training run
|
iteration: identifier for current training run
|
||||||
initial_epoch: the epoch to start training in
|
initial_epoch: the epoch to start training in
|
||||||
@ -536,7 +536,7 @@ def train_keras(train_generator: callable,
|
|||||||
ssd_loss = keras_ssd_loss.SSDLoss()
|
ssd_loss = keras_ssd_loss.SSDLoss()
|
||||||
|
|
||||||
# compile the model
|
# compile the model
|
||||||
ssd_model.model.compile(
|
ssd_model.compile(
|
||||||
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||||
beta1=0.9, beta2=0.999),
|
beta1=0.9, beta2=0.999),
|
||||||
loss=ssd_loss.compute_loss,
|
loss=ssd_loss.compute_loss,
|
||||||
@ -562,7 +562,7 @@ def train_keras(train_generator: callable,
|
|||||||
if tensorboard_callback is not None:
|
if tensorboard_callback is not None:
|
||||||
callbacks.append(tensorboard_callback)
|
callbacks.append(tensorboard_callback)
|
||||||
|
|
||||||
history = ssd_model.model.fit_generator(generator=train_generator,
|
history = ssd_model.fit_generator(generator=train_generator,
|
||||||
epochs=nr_epochs,
|
epochs=nr_epochs,
|
||||||
steps_per_epoch=steps_per_epoch_train,
|
steps_per_epoch=steps_per_epoch_train,
|
||||||
validation_data=val_generator,
|
validation_data=val_generator,
|
||||||
@ -570,8 +570,8 @@ def train_keras(train_generator: callable,
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
initial_epoch=initial_epoch)
|
initial_epoch=initial_epoch)
|
||||||
|
|
||||||
ssd_model.model.save(f"{checkpoint_dir}/ssd300.h5")
|
ssd_model.save(f"{checkpoint_dir}/ssd300.h5")
|
||||||
ssd_model.model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user