Fixed data loading

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-10 15:17:24 +02:00
parent f81ecbac0d
commit da6e348edc

View File

@ -21,15 +21,16 @@ Functions:
load_coco(...): loads the COCO data into a Tensorflow data set load_coco(...): loads the COCO data into a Tensorflow data set
load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set
""" """
from typing import List
from typing import Sequence from typing import Sequence
from typing import Tuple
import tensorflow as tf import tensorflow as tf
from pycocotools import coco from pycocotools import coco
def load_coco(data_path: str, category: int, def load_coco(data_path: str, category: int,
num_epochs: int, batch_size: int = 32) -> tf.data.Dataset: num_epochs: int, batch_size: int = 32,
resized_shape: Sequence[int] = (256, 256)) -> tf.data.Dataset:
""" """
Loads the COCO trainval35k data and returns a data set. Loads the COCO trainval35k data and returns a data set.
@ -38,6 +39,7 @@ def load_coco(data_path: str, category: int,
category: id of the inlying class category: id of the inlying class
num_epochs: number of epochs num_epochs: number of epochs
batch_size: batch size (default: 32) batch_size: batch size (default: 32)
resized_shape: shape of images after resizing them (default: (256, 256))
Returns: Returns:
Tensorflow data set Tensorflow data set
""" """
@ -49,7 +51,7 @@ def load_coco(data_path: str, category: int,
images = coco_train.loadImgs(img_ids) # load all images images = coco_train.loadImgs(img_ids) # load all images
annotation_ids = coco_train.getAnnIds(img_ids) annotation_ids = coco_train.getAnnIds(img_ids)
annotations = coco_train.loadAnns(annotation_ids) # load all image annotations annotations = coco_train.loadAnns(annotation_ids) # load all image annotations
file_names = [f"train2014/{image['file_name']}" for image in images] file_names = [f"{data_path}/images/train2014/{image['file_name']}" for image in images]
bboxes = [annotation['bbox'] for annotation in annotations] bboxes = [annotation['bbox'] for annotation in annotations]
# load validation images # load validation images
@ -58,7 +60,7 @@ def load_coco(data_path: str, category: int,
images = coco_val.loadImgs(img_ids) # load all images images = coco_val.loadImgs(img_ids) # load all images
annotation_ids = coco_val.getAnnIds(img_ids) annotation_ids = coco_val.getAnnIds(img_ids)
annotations = coco_val.loadAnns(annotation_ids) # load all image annotations annotations = coco_val.loadAnns(annotation_ids) # load all image annotations
file_names_val = [f"val2014/{image['file_name']}" for image in images] file_names_val = [f"{data_path}/images/val2014/{image['file_name']}" for image in images]
bboxes_val = [annotation['bbox'] for annotation in annotations] bboxes_val = [annotation['bbox'] for annotation in annotations]
file_names.extend(file_names_val) file_names.extend(file_names_val)
@ -66,14 +68,28 @@ def load_coco(data_path: str, category: int,
length_dataset = len(file_names) length_dataset = len(file_names)
def _load_image(image_data: Tuple[str, Sequence[int]]): def _load_image(paths: Sequence[str], labels: Sequence[Sequence[float]]):
path, label = image_data _images = tf.map_fn(lambda path: tf.read_file(path), paths)
image = tf.read_file(f"{data_path}/images/{path}")
image = tf.image.decode_image(image, channels=3)
x1, x2, y1, y2 = label[0], label[0] + label[2], label[1], label[1] + label[3]
image_cut = image[x1:x2 + 1, y1:y2 + 1]
return image_cut, label def _get_images(image_data: Sequence[tf.Tensor]) -> List[tf.Tensor]:
image = tf.image.decode_image(image_data[0], channels=3, dtype=tf.float32)
image_shape = tf.shape(image)
image = tf.reshape(image, [image_shape[0], image_shape[1], 3])
label = image_data[1]
x1, x2, y1, y2 = tf.cast(tf.round(label[0]), dtype=tf.int32), \
tf.cast(tf.round(label[0] + label[2]), dtype=tf.int32), \
tf.cast(tf.round(label[1]), dtype=tf.int32), \
tf.cast(tf.round(label[1] + label[3]), dtype=tf.int32)
# image_cut = image[x1:x2 + 1, y1:y2 + 1]
image_resized = tf.image.resize_image_with_pad(image, resized_shape[0], resized_shape[1])
return [image_resized, label]
processed = tf.map_fn(_get_images, [_images, labels], dtype=[tf.float32, tf.float32])
processed_images = processed[0]
processed_images = tf.reshape(processed_images, [-1, resized_shape[0], resized_shape[1], 3])
return processed_images
# build image data set # build image data set
path_dataset = tf.data.Dataset.from_tensor_slices(file_names) path_dataset = tf.data.Dataset.from_tensor_slices(file_names)
@ -98,5 +114,3 @@ def load_scenenet(data_path: str, num_epochs: int, batch_size: int = 32) -> tf.d
Tensorflow data set Tensorflow data set
""" """
pass pass