From 9a7a906acf12667cb1b9c6ffdb44b514abd0178c Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Wed, 20 Mar 2019 14:54:28 +0100 Subject: [PATCH] Added ability to load trained weights Signed-off-by: Jim Martens --- src/twomartens/masterthesis/ssd.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 2e5caf2..4013e16 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -29,6 +29,8 @@ Classes: ``SSD``: wraps vanilla SSD 300 model """ +from typing import Optional + import tensorflow as tf from twomartens.masterthesis.ssd_keras.models import keras_ssd300 @@ -47,11 +49,16 @@ class SSD: Args: mode: one of training, inference, and inference_fast + weights_path: path to trained weights """ - def __init__(self, mode: str) -> None: + def __init__(self, mode: str, weights_path: Optional[str] = None) -> None: self._model = keras_ssd300.ssd_300(image_size=IMAGE_SIZE, n_classes=N_CLASSES, mode=mode) self.mode = mode + + # load existing weights + if weights_path is not None: + self._model.load_weights(weights_path, by_name=True) def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor: return self._model(inputs) @@ -63,12 +70,17 @@ class DropoutSSD: Args: mode: one of training, inference, and inference_fast + weights_path: path to trained weights """ - def __init__(self, mode: str) -> None: + def __init__(self, mode: str, weights_path: Optional[str] = None) -> None: self._model = keras_ssd300_dropout.ssd_300_dropout(image_size=IMAGE_SIZE, n_classes=N_CLASSES, dropout_rate=DROPOUT_RATE, mode=mode) self.mode = mode + + # load existing weights + if weights_path is not None: + self._model.load_weights(weights_path, by_name=True) def __call__(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor: return self._model(inputs)