Improved data generation to cover evaluation cases as well

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-06-13 16:51:59 +02:00
parent f8bed423e4
commit 77a195a144

View File

@ -236,7 +236,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
coco_path: str, predictor_sizes: np.ndarray, coco_path: str, predictor_sizes: np.ndarray,
batch_size: int, batch_size: int,
resized_shape: Sequence[int], resized_shape: Sequence[int],
training: bool) -> Tuple[callable, int]: training: bool,
evaluation: bool) -> Tuple[callable, int]:
""" """
Loads the SceneNet RGB-D data and returns a data set. Loads the SceneNet RGB-D data and returns a data set.
@ -248,6 +249,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
batch_size: size of every batch batch_size: size of every batch
resized_shape: shape of input images to SSD resized_shape: shape of input images to SSD
training: True if training data is desired training: True if training data is desired
evaluation: True if evaluation-ready data is desired
Returns: Returns:
scenenet data set generator scenenet data set generator
@ -305,14 +307,15 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
returns = {'processed_images', 'encoded_labels'} returns = {'processed_images', 'encoded_labels'}
if not training: if not training and evaluation:
returns.update({'inverse_transform'}) returns = {
'processed_images',
generator = data_generator.generate( 'filenames',
batch_size=batch_size, 'inverse_transform',
shuffle=shuffle, 'original_labels'}
transformations=transformations, label_encoder = None
label_encoder=ssd_input_encoder.SSDInputEncoder( else:
label_encoder = ssd_input_encoder.SSDInputEncoder(
img_height=resized_shape[0], img_height=resized_shape[0],
img_width=resized_shape[1], img_width=resized_shape[1],
n_classes=len(cats_to_classes), # 80 n_classes=len(cats_to_classes), # 80
@ -324,7 +327,13 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
[1.0, 2.0, 0.5, 3.0, 1.0 / 3.0], [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0],
[1.0, 2.0, 0.5], [1.0, 2.0, 0.5],
[1.0, 2.0, 0.5]] [1.0, 2.0, 0.5]]
), )
generator = data_generator.generate(
batch_size=batch_size,
shuffle=shuffle,
transformations=transformations,
label_encoder=label_encoder,
returns=returns, returns=returns,
keep_images_without_gt=False keep_images_without_gt=False
) )