Ensure training only affects classifier layers
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -80,6 +80,18 @@ class SSD:
|
||||
# load existing weights
|
||||
if weights_path is not None:
|
||||
self._model.load_weights(weights_path, by_name=True)
|
||||
|
||||
if mode == "training":
|
||||
# set non-classifier layers to non-trainable
|
||||
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||
'fc7_mbox_conf',
|
||||
'conv6_2_mbox_conf',
|
||||
'conv7_2_mbox_conf',
|
||||
'conv8_2_mbox_conf',
|
||||
'conv9_2_mbox_conf']
|
||||
for layer in self._model.layers:
|
||||
if layer.name not in classifier_names:
|
||||
layer.trainable = False
|
||||
|
||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||
return self._model(inputs)
|
||||
@ -104,6 +116,18 @@ class DropoutSSD:
|
||||
# load existing weights
|
||||
if weights_path is not None:
|
||||
self._model.load_weights(weights_path, by_name=True)
|
||||
|
||||
if mode == "training":
|
||||
# set non-classifier layers to non-trainable
|
||||
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||
'fc7_mbox_conf',
|
||||
'conv6_2_mbox_conf',
|
||||
'conv7_2_mbox_conf',
|
||||
'conv8_2_mbox_conf',
|
||||
'conv9_2_mbox_conf']
|
||||
for layer in self._model.layers:
|
||||
if layer.name not in classifier_names:
|
||||
layer.trainable = False
|
||||
|
||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||
return self._model(inputs)
|
||||
|
||||
Reference in New Issue
Block a user