Moved public train function before the private functions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -187,6 +187,63 @@ def predict(generator: callable,
|
|||||||
nr_digits=nr_digits))
|
nr_digits=nr_digits))
|
||||||
|
|
||||||
|
|
||||||
|
def train(train_generator: callable,
|
||||||
|
steps_per_epoch_train: int,
|
||||||
|
val_generator: callable,
|
||||||
|
steps_per_epoch_val: int,
|
||||||
|
ssd_model: tf.keras.models.Model,
|
||||||
|
weights_prefix: str,
|
||||||
|
iteration: int,
|
||||||
|
initial_epoch: int,
|
||||||
|
nr_epochs: int,
|
||||||
|
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
||||||
|
"""
|
||||||
|
Trains the SSD on the given data set using Keras functionality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_generator: generator of training data
|
||||||
|
steps_per_epoch_train: number of batches per training epoch
|
||||||
|
val_generator: generator of validation data
|
||||||
|
steps_per_epoch_val: number of batches per validation epoch
|
||||||
|
ssd_model: compiled SSD model
|
||||||
|
weights_prefix: prefix for weights directory
|
||||||
|
iteration: identifier for current training run
|
||||||
|
initial_epoch: the epoch to start training in
|
||||||
|
nr_epochs: number of epochs to train
|
||||||
|
tensorboard_callback: initialised TensorBoard callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
checkpoint_dir = os.path.join(weights_prefix, str(iteration))
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
callbacks = [
|
||||||
|
tf.keras.callbacks.ModelCheckpoint(
|
||||||
|
filepath=f"{checkpoint_dir}/ssd300-{{epoch:02d}}_loss-{{loss:.4f}}_val_loss-{{val_loss:.4f}}.h5",
|
||||||
|
monitor="val_loss",
|
||||||
|
verbose=1,
|
||||||
|
save_best_only=True,
|
||||||
|
save_weights_only=False
|
||||||
|
),
|
||||||
|
tf.keras.callbacks.TerminateOnNaN(),
|
||||||
|
# tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss")
|
||||||
|
]
|
||||||
|
if tensorboard_callback is not None:
|
||||||
|
callbacks.append(tensorboard_callback)
|
||||||
|
|
||||||
|
history = ssd_model.fit_generator(generator=train_generator,
|
||||||
|
epochs=nr_epochs,
|
||||||
|
steps_per_epoch=steps_per_epoch_train,
|
||||||
|
validation_data=val_generator,
|
||||||
|
validation_steps=steps_per_epoch_val,
|
||||||
|
callbacks=callbacks,
|
||||||
|
initial_epoch=initial_epoch)
|
||||||
|
|
||||||
|
ssd_model.save(f"{checkpoint_dir}/ssd300.h5")
|
||||||
|
ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
||||||
|
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, str]:
|
def _predict_prepare_paths(output_path: str, use_dropout: bool) -> Tuple[str, str]:
|
||||||
filename = "ssd_predictions"
|
filename = "ssd_predictions"
|
||||||
label_filename = "ssd_labels"
|
label_filename = "ssd_labels"
|
||||||
@ -269,63 +326,6 @@ def _save_predictions(transformed_predictions: np.ndarray, original_labels: np.n
|
|||||||
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
|
pickle.dump({"labels": original_labels, "filenames": filenames}, label_file)
|
||||||
|
|
||||||
|
|
||||||
def train(train_generator: callable,
|
|
||||||
steps_per_epoch_train: int,
|
|
||||||
val_generator: callable,
|
|
||||||
steps_per_epoch_val: int,
|
|
||||||
ssd_model: tf.keras.models.Model,
|
|
||||||
weights_prefix: str,
|
|
||||||
iteration: int,
|
|
||||||
initial_epoch: int,
|
|
||||||
nr_epochs: int,
|
|
||||||
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
|
||||||
"""
|
|
||||||
Trains the SSD on the given data set using Keras functionality.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_generator: generator of training data
|
|
||||||
steps_per_epoch_train: number of batches per training epoch
|
|
||||||
val_generator: generator of validation data
|
|
||||||
steps_per_epoch_val: number of batches per validation epoch
|
|
||||||
ssd_model: compiled SSD model
|
|
||||||
weights_prefix: prefix for weights directory
|
|
||||||
iteration: identifier for current training run
|
|
||||||
initial_epoch: the epoch to start training in
|
|
||||||
nr_epochs: number of epochs to train
|
|
||||||
tensorboard_callback: initialised TensorBoard callback
|
|
||||||
"""
|
|
||||||
|
|
||||||
checkpoint_dir = os.path.join(weights_prefix, str(iteration))
|
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
||||||
|
|
||||||
callbacks = [
|
|
||||||
tf.keras.callbacks.ModelCheckpoint(
|
|
||||||
filepath=f"{checkpoint_dir}/ssd300-{{epoch:02d}}_loss-{{loss:.4f}}_val_loss-{{val_loss:.4f}}.h5",
|
|
||||||
monitor="val_loss",
|
|
||||||
verbose=1,
|
|
||||||
save_best_only=True,
|
|
||||||
save_weights_only=False
|
|
||||||
),
|
|
||||||
tf.keras.callbacks.TerminateOnNaN(),
|
|
||||||
# tf.keras.callbacks.EarlyStopping(patience=2, min_delta=0.001, monitor="val_loss")
|
|
||||||
]
|
|
||||||
if tensorboard_callback is not None:
|
|
||||||
callbacks.append(tensorboard_callback)
|
|
||||||
|
|
||||||
history = ssd_model.fit_generator(generator=train_generator,
|
|
||||||
epochs=nr_epochs,
|
|
||||||
steps_per_epoch=steps_per_epoch_train,
|
|
||||||
validation_data=val_generator,
|
|
||||||
validation_steps=steps_per_epoch_val,
|
|
||||||
callbacks=callbacks,
|
|
||||||
initial_epoch=initial_epoch)
|
|
||||||
|
|
||||||
ssd_model.save(f"{checkpoint_dir}/ssd300.h5")
|
|
||||||
ssd_model.save_weights(f"{checkpoint_dir}/ssd300_weights.h5")
|
|
||||||
|
|
||||||
return history
|
|
||||||
|
|
||||||
|
|
||||||
def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[np.ndarray]]:
|
def _get_observations(detections: Sequence[Sequence[np.ndarray]]) -> List[List[np.ndarray]]:
|
||||||
batch_size = len(detections)
|
batch_size = len(detections)
|
||||||
observations = [[] for _ in range(batch_size)]
|
observations = [[] for _ in range(batch_size)]
|
||||||
|
|||||||
Reference in New Issue
Block a user