Disable augmentation of input for now
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -75,13 +75,13 @@ def _ssd_train(args: argparse.Namespace) -> None:
|
|||||||
predictor_sizes=ssd_model.predictor_sizes,
|
predictor_sizes=ssd_model.predictor_sizes,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
resized_shape=(image_size, image_size),
|
resized_shape=(image_size, image_size),
|
||||||
training=True, evaluation=False)
|
training=True, evaluation=False, augment=False)
|
||||||
val_generator, val_length = \
|
val_generator, val_length = \
|
||||||
data.load_scenenet_data(file_names_val, instances_val, args.coco_path,
|
data.load_scenenet_data(file_names_val, instances_val, args.coco_path,
|
||||||
predictor_sizes=ssd_model.predictor_sizes,
|
predictor_sizes=ssd_model.predictor_sizes,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
resized_shape=(image_size, image_size),
|
resized_shape=(image_size, image_size),
|
||||||
training=False, evaluation=False)
|
training=False, evaluation=False, augment=False)
|
||||||
del file_names_train, instances_train, file_names_val, instances_val
|
del file_names_train, instances_train, file_names_val, instances_val
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
|
|||||||
@ -238,7 +238,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
resized_shape: Sequence[int],
|
resized_shape: Sequence[int],
|
||||||
training: bool,
|
training: bool,
|
||||||
evaluation: bool) -> Tuple[callable, int]:
|
evaluation: bool,
|
||||||
|
augment: bool) -> Tuple[callable, int]:
|
||||||
"""
|
"""
|
||||||
Loads the SceneNet RGB-D data and returns a data set.
|
Loads the SceneNet RGB-D data and returns a data set.
|
||||||
|
|
||||||
@ -251,6 +252,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
resized_shape: shape of input images to SSD
|
resized_shape: shape of input images to SSD
|
||||||
training: True if training data is desired
|
training: True if training data is desired
|
||||||
evaluation: True if evaluation-ready data is desired
|
evaluation: True if evaluation-ready data is desired
|
||||||
|
augment: True if training data should be augmented
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
scenenet data set generator
|
scenenet data set generator
|
||||||
@ -296,14 +298,14 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
labels=final_labels
|
labels=final_labels
|
||||||
)
|
)
|
||||||
|
|
||||||
if training:
|
shuffle = True if training else False
|
||||||
shuffle = True
|
|
||||||
|
if training and augment:
|
||||||
transformations = [data_augmentation_chain_original_ssd.SSDDataAugmentation(
|
transformations = [data_augmentation_chain_original_ssd.SSDDataAugmentation(
|
||||||
img_width=resized_shape[0],
|
img_width=resized_shape[0],
|
||||||
img_height=resized_shape[1]
|
img_height=resized_shape[1]
|
||||||
)]
|
)]
|
||||||
else:
|
else:
|
||||||
shuffle = False
|
|
||||||
transformations = [
|
transformations = [
|
||||||
object_detection_2d_photometric_ops.ConvertTo3Channels(),
|
object_detection_2d_photometric_ops.ConvertTo3Channels(),
|
||||||
object_detection_2d_geometric_ops.Resize(height=resized_shape[0],
|
object_detection_2d_geometric_ops.Resize(height=resized_shape[0],
|
||||||
|
|||||||
Reference in New Issue
Block a user