Added train function which utilises the Keras train functions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -359,6 +359,81 @@ def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[n
|
||||
return observations
|
||||
|
||||
|
||||
def train_keras(train_generator: callable,
|
||||
steps_per_epoch_train: int,
|
||||
val_generator: callable,
|
||||
steps_per_epoch_val: int,
|
||||
ssd_model: Union[SSD, DropoutSSD],
|
||||
weights_prefix: str,
|
||||
iteration: int,
|
||||
initial_epoch: int,
|
||||
nr_epochs: int,
|
||||
lr: float,
|
||||
tensorboard_callback: tf.keras.callbacks.TensorBoard) -> tf.keras.callbacks.History:
|
||||
"""
|
||||
Trains the SSD on the given data set using Keras functionality.
|
||||
|
||||
Args:
|
||||
train_generator: generator of training data
|
||||
steps_per_epoch_train: number of batches per training epoch
|
||||
val_generator: generator of validation data
|
||||
steps_per_epoch_val: number of batches per validation epoch
|
||||
ssd_model: wrapper of SSD model
|
||||
weights_prefix: prefix for weights directory
|
||||
iteration: identifier for current training run
|
||||
initial_epoch: the epoch to start training in
|
||||
nr_epochs: number of epochs to train
|
||||
lr: initial learning rate
|
||||
tensorboard_callback: initialised TensorBoard callback
|
||||
"""
|
||||
|
||||
# set up variables
|
||||
learning_rate_var = K.variable(lr)
|
||||
ssd_loss = keras_ssd_loss.SSDLoss()
|
||||
|
||||
# compile the model
|
||||
ssd_model.model.compile(
|
||||
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=0.5, beta2=0.999),
|
||||
loss=ssd_loss.compute_loss,
|
||||
metrics=[
|
||||
tf.keras.metrics.Precision(),
|
||||
tf.keras.metrics.Recall(),
|
||||
tf.keras.metrics.FalsePositives(),
|
||||
tf.keras.metrics.CategoricalAccuracy()
|
||||
]
|
||||
)
|
||||
|
||||
checkpoint_dir = os.path.join(weights_prefix, str(iteration), "/")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
callbacks = [
|
||||
tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=f"{checkpoint_dir}ssd300-{{epoch:02d}}_loss-{{loss:.4f}}_val_loss-{{val_loss:.4f}}.h5",
|
||||
monitor="val_loss",
|
||||
verbose=1,
|
||||
save_best_only=True,
|
||||
save_weights_only=False
|
||||
),
|
||||
tf.keras.callbacks.TerminateOnNaN(),
|
||||
tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss"),
|
||||
tensorboard_callback
|
||||
]
|
||||
|
||||
history = ssd_model.model.fit_generator(generator=train_generator,
|
||||
epochs=nr_epochs,
|
||||
steps_per_epoch=steps_per_epoch_train,
|
||||
validation_data=val_generator,
|
||||
validation_steps=steps_per_epoch_val,
|
||||
callbacks=callbacks,
|
||||
initial_epoch=initial_epoch)
|
||||
|
||||
ssd_model.model.save(f"{checkpoint_dir}ssd300.h5")
|
||||
ssd_model.model.save_weights(f"{checkpoint_dir}ssd300_weights.h5")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def train(dataset: tf.data.Dataset,
|
||||
iteration: int,
|
||||
use_dropout: bool,
|
||||
|
||||
Reference in New Issue
Block a user