Added extra functions to compile model and get loss function
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -21,6 +21,8 @@ Attributes:
|
|||||||
N_CLASSES: number of known classes (without background)
|
N_CLASSES: number of known classes (without background)
|
||||||
|
|
||||||
Functions:
|
Functions:
|
||||||
|
compile_model(...): compiles an SSD model
|
||||||
|
get_loss_func(...): returns the SSD loss function
|
||||||
get_model(...): returns correct SSD model and corresponding predictor sizes
|
get_model(...): returns correct SSD model and corresponding predictor sizes
|
||||||
predict(...): runs trained SSD/DropoutSSD on a given data set
|
predict(...): runs trained SSD/DropoutSSD on a given data set
|
||||||
train(...): trains the SSD/DropoutSSD on a given data set
|
train(...): trains the SSD/DropoutSSD on a given data set
|
||||||
@ -113,6 +115,35 @@ def get_model(use_dropout: bool,
|
|||||||
return model, predictor_sizes
|
return model, predictor_sizes
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss_func() -> callable:
|
||||||
|
return keras_ssd_loss.SSDLoss().compute_loss
|
||||||
|
|
||||||
|
|
||||||
|
def compile_model(model: tf.keras.models.Model, learning_rate: float, loss_func: callable) -> None:
|
||||||
|
"""
|
||||||
|
Compiles an SSD model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SSD model
|
||||||
|
learning_rate: the learning rate
|
||||||
|
loss_func: loss function to minimize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
learning_rate_var = K.variable(learning_rate)
|
||||||
|
|
||||||
|
# compile the model
|
||||||
|
model.compile(
|
||||||
|
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||||
|
beta1=0.9, beta2=0.999),
|
||||||
|
loss=loss_func,
|
||||||
|
metrics=[
|
||||||
|
"categorical_accuracy"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def predict(generator: callable,
|
def predict(generator: callable,
|
||||||
steps_per_epoch: int,
|
steps_per_epoch: int,
|
||||||
ssd_model: tf.keras.models.Model,
|
ssd_model: tf.keras.models.Model,
|
||||||
@ -245,7 +276,6 @@ def train(train_generator: callable,
|
|||||||
iteration: int,
|
iteration: int,
|
||||||
initial_epoch: int,
|
initial_epoch: int,
|
||||||
nr_epochs: int,
|
nr_epochs: int,
|
||||||
lr: float,
|
|
||||||
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
tensorboard_callback: Optional[tf.keras.callbacks.TensorBoard]) -> tf.keras.callbacks.History:
|
||||||
"""
|
"""
|
||||||
Trains the SSD on the given data set using Keras functionality.
|
Trains the SSD on the given data set using Keras functionality.
|
||||||
@ -255,29 +285,14 @@ def train(train_generator: callable,
|
|||||||
steps_per_epoch_train: number of batches per training epoch
|
steps_per_epoch_train: number of batches per training epoch
|
||||||
val_generator: generator of validation data
|
val_generator: generator of validation data
|
||||||
steps_per_epoch_val: number of batches per validation epoch
|
steps_per_epoch_val: number of batches per validation epoch
|
||||||
ssd_model: SSD model
|
ssd_model: compiled SSD model
|
||||||
weights_prefix: prefix for weights directory
|
weights_prefix: prefix for weights directory
|
||||||
iteration: identifier for current training run
|
iteration: identifier for current training run
|
||||||
initial_epoch: the epoch to start training in
|
initial_epoch: the epoch to start training in
|
||||||
nr_epochs: number of epochs to train
|
nr_epochs: number of epochs to train
|
||||||
lr: initial learning rate
|
|
||||||
tensorboard_callback: initialised TensorBoard callback
|
tensorboard_callback: initialised TensorBoard callback
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# set up variables
|
|
||||||
learning_rate_var = K.variable(lr)
|
|
||||||
ssd_loss = keras_ssd_loss.SSDLoss()
|
|
||||||
|
|
||||||
# compile the model
|
|
||||||
ssd_model.compile(
|
|
||||||
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate_var,
|
|
||||||
beta1=0.9, beta2=0.999),
|
|
||||||
loss=ssd_loss.compute_loss,
|
|
||||||
metrics=[
|
|
||||||
"categorical_accuracy"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoint_dir = os.path.join(weights_prefix, str(iteration))
|
checkpoint_dir = os.path.join(weights_prefix, str(iteration))
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user