Added function to retrieve SSD model

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-10 15:24:49 +02:00
parent b01cda512a
commit ed3fec54af

View File

@ -152,6 +152,72 @@ class DropoutSSD:
return self.model(inputs)
def get_model(use_dropout: bool,
dropout_model: callable, vanilla_model: callable,
image_size: int, nr_classes: int, mode: str,
iou_threshold: float, dropout_rate: float, top_k: int,
pre_trained_weights_file: Optional[str] = None) -> Tuple[tf.keras.models.Model, np.ndarray]:
"""
Returns the correct SSD model and the corresponding predictor sizes.
Args:
use_dropout: True if dropout variant should be used, False otherwise
dropout_model: function to build dropout SSD model
vanilla_model: function to build vanilla SSD model
image_size: size of the resized images
nr_classes: number of classes
mode: one of "training", "inference", "inference_fast"
iou_threshold: all boxes with higher iou to local maximum box are suppressed
dropout_rate: rate for dropout layers (only applies if dropout is used)
top_k: number of highest scoring predictions kept for each batch item
pre_trained_weights_file: path to h5 file with pre-trained weights
Returns:
SSD model, predictor_sizes
"""
image_size = (image_size, image_size, 3)
scales = [0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]
if use_dropout:
model, predictor_sizes = dropout_model(
image_size=image_size,
n_classes=nr_classes,
mode=mode,
iou_threshold=iou_threshold,
dropout_rate=dropout_rate,
top_k=top_k,
scales=scales,
return_predictor_sizes=True
)
else:
model, predictor_sizes = vanilla_model(
image_size=image_size,
n_classes=nr_classes,
mode=mode,
iou_threshold=iou_threshold,
top_k=top_k,
scales=scales,
return_predictor_sizes=True
)
if mode == "training":
# set non-classifier layers to non-trainable
classifier_names = ['conv4_3_norm_mbox_conf',
'fc7_mbox_conf',
'conv6_2_mbox_conf',
'conv7_2_mbox_conf',
'conv8_2_mbox_conf',
'conv9_2_mbox_conf']
for layer in model.layers:
if layer.name not in classifier_names:
layer.trainable = False
if pre_trained_weights_file is not None:
model.load_weights(pre_trained_weights_file, by_name=True)
return model, predictor_sizes
def predict_keras(generator: callable,
steps_per_epoch: int,
ssd_model: tf.keras.models.Model,