Added ability limit number of trajectories considered
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -23,7 +23,7 @@ Functions:
|
|||||||
load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set
|
load_scenenet_data(...): loads the SceneNet RGB-D data into a Tensorflow data set
|
||||||
prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format
|
prepare_scenenet_data(...): prepares the SceneNet RGB-D data and returns it in Python format
|
||||||
"""
|
"""
|
||||||
from typing import Callable, List, Mapping, Tuple
|
from typing import Callable, List, Mapping, Tuple, Optional
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -239,7 +239,8 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
resized_shape: Sequence[int],
|
resized_shape: Sequence[int],
|
||||||
training: bool,
|
training: bool,
|
||||||
evaluation: bool,
|
evaluation: bool,
|
||||||
augment: bool) -> Tuple[callable, int]:
|
augment: bool,
|
||||||
|
nr_trajectories: Optional[int] = None) -> Tuple[callable, int]:
|
||||||
"""
|
"""
|
||||||
Loads the SceneNet RGB-D data and returns a data set.
|
Loads the SceneNet RGB-D data and returns a data set.
|
||||||
|
|
||||||
@ -253,6 +254,7 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
training: True if training data is desired
|
training: True if training data is desired
|
||||||
evaluation: True if evaluation-ready data is desired
|
evaluation: True if evaluation-ready data is desired
|
||||||
augment: True if training data should be augmented
|
augment: True if training data should be augmented
|
||||||
|
nr_trajectories: number of trajectories to consider
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
scenenet data set generator
|
scenenet data set generator
|
||||||
@ -268,7 +270,10 @@ def load_scenenet_data(photo_paths: Sequence[Sequence[str]],
|
|||||||
cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train)
|
cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train)
|
||||||
max_nr_labels = -1
|
max_nr_labels = -1
|
||||||
|
|
||||||
for trajectory in trajectories:
|
for i, trajectory in enumerate(trajectories):
|
||||||
|
if nr_trajectories is not None and i >= nr_trajectories:
|
||||||
|
break
|
||||||
|
|
||||||
traj_image_paths, traj_instances = trajectory
|
traj_image_paths, traj_instances = trajectory
|
||||||
for image_path, frame_instances in zip(traj_image_paths, traj_instances):
|
for image_path, frame_instances in zip(traj_image_paths, traj_instances):
|
||||||
labels = []
|
labels = []
|
||||||
|
|||||||
Reference in New Issue
Block a user