Ensure training only affects classifier layers

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-06-05 11:31:22 +02:00
parent 5e7b16402b
commit 28abe6ae47

View File

@ -80,6 +80,18 @@ class SSD:
# load existing weights # load existing weights
if weights_path is not None: if weights_path is not None:
self._model.load_weights(weights_path, by_name=True) self._model.load_weights(weights_path, by_name=True)
if mode == "training":
# set non-classifier layers to non-trainable
classifier_names = ['conv4_3_norm_mbox_conf',
'fc7_mbox_conf',
'conv6_2_mbox_conf',
'conv7_2_mbox_conf',
'conv8_2_mbox_conf',
'conv9_2_mbox_conf']
for layer in self._model.layers:
if layer.name not in classifier_names:
layer.trainable = False
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor: def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
return self._model(inputs) return self._model(inputs)
@ -104,6 +116,18 @@ class DropoutSSD:
# load existing weights # load existing weights
if weights_path is not None: if weights_path is not None:
self._model.load_weights(weights_path, by_name=True) self._model.load_weights(weights_path, by_name=True)
if mode == "training":
# set non-classifier layers to non-trainable
classifier_names = ['conv4_3_norm_mbox_conf',
'fc7_mbox_conf',
'conv6_2_mbox_conf',
'conv7_2_mbox_conf',
'conv8_2_mbox_conf',
'conv9_2_mbox_conf']
for layer in self._model.layers:
if layer.name not in classifier_names:
layer.trainable = False
def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor: def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
return self._model(inputs) return self._model(inputs)