diff --git a/src/twomartens/masterthesis/ssd.py b/src/twomartens/masterthesis/ssd.py index 4f202fa..7f11d6b 100644 --- a/src/twomartens/masterthesis/ssd.py +++ b/src/twomartens/masterthesis/ssd.py @@ -403,7 +403,8 @@ def train(dataset: tf.data.Dataset, # input encoder input_encoder = ssd_input_encoder.SSDInputEncoder(IMAGE_SIZE[0], IMAGE_SIZE[1], - N_CLASSES, ssd.predictor_sizes) + N_CLASSES, ssd.predictor_sizes, + steps=[8, 16, 32, 64, 100, 300]) def _get_last_epoch(epoch_var: tf.Variable, **kwargs) -> int: return int(epoch_var)