Ensure training only affects classifier layers
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -80,6 +80,18 @@ class SSD:
|
|||||||
# load existing weights
|
# load existing weights
|
||||||
if weights_path is not None:
|
if weights_path is not None:
|
||||||
self._model.load_weights(weights_path, by_name=True)
|
self._model.load_weights(weights_path, by_name=True)
|
||||||
|
|
||||||
|
if mode == "training":
|
||||||
|
# set non-classifier layers to non-trainable
|
||||||
|
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||||
|
'fc7_mbox_conf',
|
||||||
|
'conv6_2_mbox_conf',
|
||||||
|
'conv7_2_mbox_conf',
|
||||||
|
'conv8_2_mbox_conf',
|
||||||
|
'conv9_2_mbox_conf']
|
||||||
|
for layer in self._model.layers:
|
||||||
|
if layer.name not in classifier_names:
|
||||||
|
layer.trainable = False
|
||||||
|
|
||||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||||
return self._model(inputs)
|
return self._model(inputs)
|
||||||
@ -104,6 +116,18 @@ class DropoutSSD:
|
|||||||
# load existing weights
|
# load existing weights
|
||||||
if weights_path is not None:
|
if weights_path is not None:
|
||||||
self._model.load_weights(weights_path, by_name=True)
|
self._model.load_weights(weights_path, by_name=True)
|
||||||
|
|
||||||
|
if mode == "training":
|
||||||
|
# set non-classifier layers to non-trainable
|
||||||
|
classifier_names = ['conv4_3_norm_mbox_conf',
|
||||||
|
'fc7_mbox_conf',
|
||||||
|
'conv6_2_mbox_conf',
|
||||||
|
'conv7_2_mbox_conf',
|
||||||
|
'conv8_2_mbox_conf',
|
||||||
|
'conv9_2_mbox_conf']
|
||||||
|
for layer in self._model.layers:
|
||||||
|
if layer.name not in classifier_names:
|
||||||
|
layer.trainable = False
|
||||||
|
|
||||||
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
|
||||||
return self._model(inputs)
|
return self._model(inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user