diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 05e42dd..de59c08 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -23,7 +23,7 @@ Functions: load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format """ -from typing import Callable, List, Mapping, Tuple +from typing import Callable, List, Mapping, Tuple, Optional from typing import Sequence import numpy as np @@ -239,7 +239,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], resized_shape: Sequence[int], training: bool, evaluation: bool, - augment: bool) -> Tuple[callable, int]: + augment: bool, + nr_trajectories: Optional[int] = None) -> Tuple[callable, int]: """ Loads the SceneNet RGB-D data and returns a data set. @@ -253,6 +254,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], training: True if training data is desired evaluation: True if evaluation-ready data is desired augment: True if training data should be augmented + nr_trajectories: number of trajectories to consider Returns: scenenet data set generator @@ -268,7 +270,10 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]], cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train) max_nr_labels = -1 - for trajectory in trajectories: + for i, trajectory in enumerate(trajectories): + if nr_trajectories is not None and i >= nr_trajectories: + break + traj_image_paths, traj_instances = trajectory for image_path, frame_instances in zip(traj_image_paths, traj_instances): labels = []