Updated train_keras to work with SSD model directly
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -507,7 +507,7 @@ def train_keras(train_generator: callable,
|
||||
steps_per_epoch_train: int,
|
||||
val_generator: callable,
|
||||
steps_per_epoch_val: int,
|
||||
ssd_model: Union[SSD, DropoutSSD],
|
||||
ssd_model: tf.keras.models.Model,
|
||||
weights_prefix: str,
|
||||
iteration: int,
|
||||
initial_epoch: int,
|
||||
@ -522,7 +522,7 @@ def train_keras(train_generator: callable,
|
||||
steps_per_epoch_train: number of batches per training epoch
|
||||
val_generator: generator of validation data
|
||||
steps_per_epoch_val: number of batches per validation epoch
|
||||
ssd_model: wrapper of SSD model
|
||||
ssd_model: SSD model
|
||||
weights_prefix: prefix for weights directory
|
||||
iteration: identifier for current training run
|
||||
initial_epoch: the epoch to start training in
|
||||
@ -536,7 +536,7 @@ def train_keras(train_generator: callable,
|
||||
ssd_loss = keras_ssd_loss.SSDLoss()
|
||||
|
||||
# compile the model
|
||||
ssd_model.model.compile(
|
||||
ssd_model.compile(
|
||||
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=0.9, beta2=0.999),
|
||||
loss=ssd_loss.compute_loss,
|
||||
@ -562,7 +562,7 @@ def train_keras(train_generator: callable,
|
||||
if tensorboard_callback is not None:
|
||||
callbacks.append(tensorboard_callback)
|
||||
|
||||
history = ssd_model.model.fit_generator(generator=train_generator,
|
||||
history = ssd_model.fit_generator(generator=train_generator,
|
||||
epochs=nr_epochs,
|
||||
steps_per_epoch=steps_per_epoch_train,
|
||||
validation_data=val_generator,
|
||||
@ -570,8 +570,8 @@ def train_keras(train_generator: callable,
|
||||
callbacks=callbacks,
|
||||
initial_epoch=initial_epoch)
|
||||
|
||||
ssd_model.model.save(f"{checkpoint_dir}/ssd300.h5")
|
||||
ssd_model.model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
||||
ssd_model.save(f"{checkpoint_dir}/ssd300.h5")
|
||||
ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
Reference in New Issue
Block a user