From 77a195a144996925da13384c97a55b3b9d7080a9 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Thu, 13 Jun 2019 16:51:59 +0200 Subject: [PATCH] Improved data generation to cover evaluation cases as well Signed-off-by: Jim Martens --- src/twomartens/masterthesis/data.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 140e53b..4e5dae3 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -236,7 +236,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], coco_path: str, predictor_sizes: np.ndarray, batch_size: int, resized_shape: Sequence[int], - training: bool) -> Tuple[callable, int]: + training: bool, + evaluation: bool) -> Tuple[callable, int]: """ Loads the SceneNet RGB-D data and returns a data set. @@ -248,6 +249,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], batch_size: size of every batch resized_shape: shape of input images to SSD training: True if training data is desired + evaluation: True if evaluation-ready data is desired Returns: scenenet data set generator @@ -305,14 +307,15 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], returns = {'processed_images', 'encoded_labels'} - if not training: - returns.update({'inverse_transform'}) - - generator = data_generator.generate( - batch_size=batch_size, - shuffle=shuffle, - transformations=transformations, - label_encoder=ssd_input_encoder.SSDInputEncoder( + if not training and evaluation: + returns = { + 'processed_images', + 'filenames', + 'inverse_transform', + 'original_labels'} + label_encoder = None + else: + label_encoder = ssd_input_encoder.SSDInputEncoder( img_height=resized_shape[0], img_width=resized_shape[1], n_classes=len(cats_to_classes), # 80 @@ -324,7 +327,13 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0], [1.0, 2.0, 0.5], [1.0, 2.0, 0.5]] - ), + ) + + generator = data_generator.generate( + batch_size=batch_size, + shuffle=shuffle, + transformations=transformations, + label_encoder=label_encoder, returns=returns, keep_images_without_gt=False )