diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 79a9061..59ce0f1 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -75,13 +75,13 @@ def _ssd_train(args: argparse.Namespace) -> None: predictor_sizes=ssd_model.predictor_sizes, batch_size=batch_size, resized_shape=(image_size, image_size), - training=True, evaluation=False) + training=True, evaluation=False, augment=False) val_generator, val_length = \ data.load_scenenet_data(file_names_val, instances_val, args.coco_path, predictor_sizes=ssd_model.predictor_sizes, batch_size=batch_size, resized_shape=(image_size, image_size), - training=False, evaluation=False) + training=False, evaluation=False, augment=False) del file_names_train, instances_train, file_names_val, instances_val if args.debug: diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 5cac8eb..05e42dd 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -238,7 +238,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], batch_size: int, resized_shape: Sequence[int], training: bool, - evaluation: bool) -> Tuple[callable, int]: + evaluation: bool, + augment: bool) -> Tuple[callable, int]: """ Loads the SceneNet RGB-D data and returns a data set. @@ -251,6 +252,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], resized_shape: shape of input images to SSD training: True if training data is desired evaluation: True if evaluation-ready data is desired + augment: True if training data should be augmented Returns: scenenet data set generator @@ -296,14 +298,14 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], labels=final_labels ) - if training: - shuffle = True + shuffle = True if training else False + + if training and augment: transformations = [data_augmentation_chain_original_ssd.SSDDataAugmentation( img_width=resized_shape[0], img_height=resized_shape[1] )] else: - shuffle = False transformations = [ object_detection_2d_photometric_ops.ConvertTo3Channels(), object_detection_2d_geometric_ops.Resize(height=resized_shape[0],