From 6c431e5e707dbc0c650b2ffd2d8d72005acab6d2 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 11 Jul 2019 14:55:20 +0200 Subject: [PATCH] Moved public train function before the private functions Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 114 ++++++++++++++--------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index f6e3b03..f191a11 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -187,6 +187,63 @@ def predict(generator: callable, nr_digits=nr_digits)) +def train(train_generator: callable, + steps_per_epoch_train: int, + val_generator: callable, + steps_per_epoch_val: int, + ssd_model: tf.keras.models.Model, + weights_prefix: str, + iteration: int, + initial_epoch: int, + nr_epochs: int, + tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History: + """ + Trains the SSD on the given data set using Keras functionality. + + Args: + train_generator: generator of training data + steps_per_epoch_train: number of batches per training epoch + val_generator: generator of validation data + steps_per_epoch_val: number of batches per validation epoch + ssd_model: compiled SSD model + weights_prefix: prefix for weights directory + iteration: identifier for current training run + initial_epoch: the epoch to start training in + nr_epochs: number of epochs to train + tensorboard_callback: initialised TensorBoard callback + """ + + checkpoint_dir = os.path.join(weights_prefix, str(iteration)) + os.makedirs(checkpoint_dir, exist_ok=True) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=f"{checkpoint_dir}/ssd300-{{epoch:02d}}_loss-{{loss:.4f}}_val_loss-{{val_loss:.4f}}.h5", + monitor="val_loss", + verbose=1, + save_best_only=True, + save_weights_only=False + ), + tf.keras.callbacks.TerminateOnNaN(), + # tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss") + ] + if tensorboard_callback is not None: + callbacks.append(tensorboard_callback) + + history = ssd_model.fit_generator(generator=train_generator, + epochs=nr_epochs, + steps_per_epoch=steps_per_epoch_train, + validation_data=val_generator, + validation_steps=steps_per_epoch_val, + callbacks=callbacks, + initial_epoch=initial_epoch) + + ssd_model.save(f"{checkpoint_dir}/ssd300.h5") + ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5") + + return history + + def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, str]: filename = "ssd_predictions" label_filename = "ssd_labels" @@ -269,63 +326,6 @@ def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.n pickle.dump({"labels": original_labels, "filenames": filenames}, label_file) -def train(train_generator: callable, - steps_per_epoch_train: int, - val_generator: callable, - steps_per_epoch_val: int, - ssd_model: tf.keras.models.Model, - weights_prefix: str, - iteration: int, - initial_epoch: int, - nr_epochs: int, - tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History: - """ - Trains the SSD on the given data set using Keras functionality. - - Args: - train_generator: generator of training data - steps_per_epoch_train: number of batches per training epoch - val_generator: generator of validation data - steps_per_epoch_val: number of batches per validation epoch - ssd_model: compiled SSD model - weights_prefix: prefix for weights directory - iteration: identifier for current training run - initial_epoch: the epoch to start training in - nr_epochs: number of epochs to train - tensorboard_callback: initialised TensorBoard callback - """ - - checkpoint_dir = os.path.join(weights_prefix, str(iteration)) - os.makedirs(checkpoint_dir, exist_ok=True) - - callbacks = [ - tf.keras.callbacks.ModelCheckpoint( - filepath=f"{checkpoint_dir}/ssd300-{{epoch:02d}}_loss-{{loss:.4f}}_val_loss-{{val_loss:.4f}}.h5", - monitor="val_loss", - verbose=1, - save_best_only=True, - save_weights_only=False - ), - tf.keras.callbacks.TerminateOnNaN(), - # tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss") - ] - if tensorboard_callback is not None: - callbacks.append(tensorboard_callback) - - history = ssd_model.fit_generator(generator=train_generator, - epochs=nr_epochs, - steps_per_epoch=steps_per_epoch_train, - validation_data=val_generator, - validation_steps=steps_per_epoch_val, - callbacks=callbacks, - initial_epoch=initial_epoch) - - ssd_model.save(f"{checkpoint_dir}/ssd300.h5") - ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5") - - return history - - def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[np.ndarray]]: batch_size = len(detections) observations = [[] for _ in range(batch_size)]