Fixed data loading

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-10 15:17:24 +02:00
parent f81ecbac0d
commit da6e348edc

View File

@ -21,15 +21,16 @@ Functions:
load_coco(...): loads the COCO data into a Tensorflow data set
load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set
"""
from typing import List
from typing import Sequence
from typing import Tuple
import tensorflow as tf
from pycocotools import coco
def load_coco(data_path: str, category: int,
num_epochs: int, batch_size: int = 32) -> tf.data.Dataset:
num_epochs: int, batch_size: int = 32,
resized_shape: Sequence[int] = (256, 256)) -> tf.data.Dataset:
"""
Loads the COCO trainval35k data and returns a data set.
@ -38,6 +39,7 @@ def load_coco(data_path: str, category: int,
category: id of the inlying class
num_epochs: number of epochs
batch_size: batch size (default: 32)
resized_shape: shape of images after resizing them (default: (300,300))
Returns:
Tensorflow data set
"""
@ -49,7 +51,7 @@ def load_coco(data_path: str, category: int,
images = coco_train.loadImgs(img_ids) # load all images
annotation_ids = coco_train.getAnnIds(img_ids)
annotations = coco_train.loadAnns(annotation_ids) # load all image annotations
file_names = [f"train2014/{image['file_name']}" for image in images]
file_names = [f"{data_path}/images/train2014/{image['file_name']}" for image in images]
bboxes = [annotation['bbox'] for annotation in annotations]
# load validation images
@ -58,7 +60,7 @@ def load_coco(data_path: str, category: int,
images = coco_val.loadImgs(img_ids) # load all images
annotation_ids = coco_val.getAnnIds(img_ids)
annotations = coco_val.loadAnns(annotation_ids) # load all image annotations
file_names_val = [f"val2014/{image['file_name']}" for image in images]
file_names_val = [f"{data_path}/images/val2014/{image['file_name']}" for image in images]
bboxes_val = [annotation['bbox'] for annotation in annotations]
file_names.extend(file_names_val)
@ -66,14 +68,28 @@ def load_coco(data_path: str, category: int,
length_dataset = len(file_names)
def _load_image(image_data: Tuple[str, Sequence[int]]):
path, label = image_data
image = tf.read_file(f"{data_path}/images/{path}")
image = tf.image.decode_image(image, channels=3)
x1, x2, y1, y2 = label[0], label[0] + label[2], label[1], label[1] + label[3]
image_cut = image[x1:x2 + 1, y1:y2 + 1]
def _load_image(paths: Sequence[str], labels: Sequence[Sequence[float]]):
_images = tf.map_fn(lambda path: tf.read_file(path), paths)
return image_cut, label
def _get_images(image_data: Sequence[tf.Tensor]) -> List[tf.Tensor]:
image = tf.image.decode_image(image_data[0], channels=3, dtype=tf.float32)
image_shape = tf.shape(image)
image = tf.reshape(image, [image_shape[0], image_shape[1], 3])
label = image_data[1]
x1, x2, y1, y2 = tf.cast(tf.round(label[0]), dtype=tf.int32), \
tf.cast(tf.round(label[0] + label[2]), dtype=tf.int32), \
tf.cast(tf.round(label[1]), dtype=tf.int32), \
tf.cast(tf.round(label[1] + label[3]), dtype=tf.int32)
# image_cut = image[x1:x2 + 1, y1:y2 + 1]
image_resized = tf.image.resize_image_with_pad(image, resized_shape[0], resized_shape[1])
return [image_resized, label]
processed = tf.map_fn(_get_images, [_images, labels], dtype=[tf.float32, tf.float32])
processed_images = processed[0]
processed_images = tf.reshape(processed_images, [-1, resized_shape[0], resized_shape[1], 3])
return processed_images
# build image data set
path_dataset = tf.data.Dataset.from_tensor_slices(file_names)
@ -98,5 +114,3 @@ def load_scenenet(data_path: str, num_epochs: int, batch_size: int = 32) -> tf.d
Tensorflow data set
"""
pass