Added ability limit number of trajectories considered

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-04 13:15:33 +02:00
parent d3be3ea5f9
commit 7f5a32a441

View File

@ -23,7 +23,7 @@ Functions:
load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set
prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format
""" """
from typing import Callable, List, Mapping, Tuple from typing import Callable, List, Mapping, Tuple, Optional
from typing import Sequence from typing import Sequence
import numpy as np import numpy as np
@ -239,7 +239,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
resized_shape: Sequence[int], resized_shape: Sequence[int],
training: bool, training: bool,
evaluation: bool, evaluation: bool,
augment: bool) -> Tuple[callable, int]: augment: bool,
nr_trajectories: Optional[int] = None) -> Tuple[callable, int]:
""" """
Loads the SceneNet RGB-D data and returns a data set. Loads the SceneNet RGB-D data and returns a data set.
@ -253,6 +254,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
training: True if training data is desired training: True if training data is desired
evaluation: True if evaluation-ready data is desired evaluation: True if evaluation-ready data is desired
augment: True if training data should be augmented augment: True if training data should be augmented
nr_trajectories: number of trajectories to consider
Returns: Returns:
scenenet data set generator scenenet data set generator
@ -268,7 +270,10 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train) cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train)
max_nr_labels = -1 max_nr_labels = -1
for trajectory in trajectories: for i, trajectory in enumerate(trajectories):
if nr_trajectories is not None and i >= nr_trajectories:
break
traj_image_paths, traj_instances = trajectory traj_image_paths, traj_instances = trajectory
for image_path, frame_instances in zip(traj_image_paths, traj_instances): for image_path, frame_instances in zip(traj_image_paths, traj_instances):
labels = [] labels = []