def train_keras(train_generator: callable,
                steps_per_epoch_train: int,
                val_generator: callable,
                steps_per_epoch_val: int,
                ssd_model: Union[SSD, DropoutSSD],
                weights_prefix: str,
                iteration: int,
                initial_epoch: int,
                nr_epochs: int,
                lr: float,
                tensorboard_callback: tf.keras.callbacks.TensorBoard) -> tf.keras.callbacks.History:
    """
    Trains the SSD on the given data set using Keras functionality.

    The model is compiled with an Adam optimizer and the SSD loss, trained
    with ``fit_generator`` and afterwards saved both as a full model and as
    weights only. All files are written to ``<weights_prefix>/<iteration>/``,
    which is created if it does not exist yet.

    Args:
        train_generator: generator of training data
        steps_per_epoch_train: number of batches per training epoch
        val_generator: generator of validation data
        steps_per_epoch_val: number of batches per validation epoch
        ssd_model: wrapper of SSD model
        weights_prefix: prefix for weights directory
        iteration: identifier for current training run
        initial_epoch: the epoch to start training in
        nr_epochs: number of epochs to train
        lr: initial learning rate
        tensorboard_callback: initialised TensorBoard callback

    Returns:
        the history object returned by ``fit_generator``
    """

    # set up variables
    learning_rate_var = K.variable(lr)
    ssd_loss = keras_ssd_loss.SSDLoss()

    # compile the model
    ssd_model.model.compile(
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
                                         beta1=0.5, beta2=0.999),
        loss=ssd_loss.compute_loss,
        metrics=[
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.FalsePositives(),
            tf.keras.metrics.CategoricalAccuracy()
        ]
    )

    # NOTE: do not append "/" as a path component here. os.path.join discards
    # every component before an absolute one, so the result would be the
    # file system root rather than the run-specific directory.
    checkpoint_dir = os.path.join(weights_prefix, str(iteration))
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The placeholders are filled in by ModelCheckpoint at the end of an epoch,
    # hence a plain (non f-) string.
    checkpoint_pattern = os.path.join(
        checkpoint_dir,
        "ssd300-{epoch:02d}_loss-{loss:.4f}_val_loss-{val_loss:.4f}.h5"
    )

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_pattern,
            monitor="val_loss",
            verbose=1,
            save_best_only=True,
            save_weights_only=False
        ),
        tf.keras.callbacks.TerminateOnNaN(),
        tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss"),
        tensorboard_callback
    ]

    history = ssd_model.model.fit_generator(generator=train_generator,
                                            epochs=nr_epochs,
                                            steps_per_epoch=steps_per_epoch_train,
                                            validation_data=val_generator,
                                            validation_steps=steps_per_epoch_val,
                                            callbacks=callbacks,
                                            initial_epoch=initial_epoch)

    # persist the final state as a full model and as weights only
    ssd_model.model.save(os.path.join(checkpoint_dir, "ssd300.h5"))
    ssd_model.model.save_weights(os.path.join(checkpoint_dir, "ssd300_weights.h5"))

    return history