Ensure training only affects classifier layers
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -81,6 +81,18 @@ class SSD:
|
||||
if weights_path is not None:
|
||||
self._model.load_weights(weights_path, by_name=True)
|
||||
|
||||
if mode == "training":
|
||||
# set non-classifier layers to non-trainable
|
||||
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||
'fc7_mbox_conf',
|
||||
'conv6_2_mbox_conf',
|
||||
'conv7_2_mbox_conf',
|
||||
'conv8_2_mbox_conf',
|
||||
'conv9_2_mbox_conf']
|
||||
for layer in self._model.layers:
|
||||
if layer.name not in classifier_names:
|
||||
layer.trainable = False
|
||||
|
||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||
return self._model(inputs)
|
||||
|
||||
@ -105,6 +117,18 @@ class DropoutSSD:
|
||||
if weights_path is not None:
|
||||
self._model.load_weights(weights_path, by_name=True)
|
||||
|
||||
if mode == "training":
|
||||
# set non-classifier layers to non-trainable
|
||||
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||
'fc7_mbox_conf',
|
||||
'conv6_2_mbox_conf',
|
||||
'conv7_2_mbox_conf',
|
||||
'conv8_2_mbox_conf',
|
||||
'conv9_2_mbox_conf']
|
||||
for layer in self._model.layers:
|
||||
if layer.name not in classifier_names:
|
||||
layer.trainable = False
|
||||
|
||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||
return self._model(inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user