diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 331bb24..140e53b 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -302,6 +302,11 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], object_detection_2d_geometric_ops.Resize(height=resized_shape[0], width=resized_shape[1]) ] + + returns = {'processed_images', 'encoded_labels'} + + if not training: + returns.update({'inverse_transform'}) generator = data_generator.generate( batch_size=batch_size, @@ -320,7 +325,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], [1.0, 2.0, 0.5], [1.0, 2.0, 0.5]] ), - returns={'processed_images', 'encoded_labels'}, + returns=returns, keep_images_without_gt=False )