diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 78613c3..331bb24 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -23,7 +23,6 @@ Functions: load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format """ -import functools from typing import Callable, List, Mapping, Tuple from typing import Sequence @@ -304,16 +303,22 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], width=resized_shape[1]) ] - generator = functools.partial( - data_generator.generate, + generator = data_generator.generate( batch_size=batch_size, shuffle=shuffle, transformations=transformations, label_encoder=ssd_input_encoder.SSDInputEncoder( - img_height=resized_shape[0], - img_width=resized_shape[1], - n_classes=len(cats_to_classes), # 80 - predictor_sizes=predictor_sizes + img_height=resized_shape[0], + img_width=resized_shape[1], + n_classes=len(cats_to_classes), # 80 + predictor_sizes=predictor_sizes, + steps=[8, 16, 32, 64, 100, 300], + aspect_ratios_per_layer=[[1.0, 2.0, 0.5], + [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0], + [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0], + [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0], + [1.0, 2.0, 0.5], + [1.0, 2.0, 0.5]] ), returns={'processed_images', 'encoded_labels'}, keep_images_without_gt=False