Modified data module to cut everything but the bounding box

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-09 10:18:50 +02:00
parent 57da44c6d8
commit be501d606f

View File

@ -21,6 +21,7 @@ Functions:
load_coco(...): loads the COCO data into a Tensorflow data set load_coco(...): loads the COCO data into a Tensorflow data set
load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set
""" """
from typing import Sequence
from typing import Tuple from typing import Tuple
import tensorflow as tf import tensorflow as tf
@ -49,7 +50,7 @@ def load_coco(data_path: str, category: int,
annotation_ids = coco_train.getAnnIds(img_ids) annotation_ids = coco_train.getAnnIds(img_ids)
annotations = coco_train.loadAnns(annotation_ids) # load all image annotations annotations = coco_train.loadAnns(annotation_ids) # load all image annotations
file_names = [f"train2014/{image['file_name']}" for image in images] file_names = [f"train2014/{image['file_name']}" for image in images]
cat_ids = [annotation['category_id'] for annotation in annotations] bboxes = [annotation['bbox'] for annotation in annotations]
# load validation images # load validation images
coco_val = coco.COCO(annotation_file_val) coco_val = coco.COCO(annotation_file_val)
@ -58,23 +59,25 @@ def load_coco(data_path: str, category: int,
annotation_ids = coco_val.getAnnIds(img_ids) annotation_ids = coco_val.getAnnIds(img_ids)
annotations = coco_val.loadAnns(annotation_ids) # load all image annotations annotations = coco_val.loadAnns(annotation_ids) # load all image annotations
file_names_val = [f"val2014/{image['file_name']}" for image in images] file_names_val = [f"val2014/{image['file_name']}" for image in images]
cat_ids_val = [annotation['category_id'] for annotation in annotations] bboxes_val = [annotation['bbox'] for annotation in annotations]
file_names.extend(file_names_val) file_names.extend(file_names_val)
cat_ids.extend(cat_ids_val) bboxes.extend(bboxes_val)
length_dataset = len(file_names) length_dataset = len(file_names)
def _load_image(image_data: Tuple[str, Sequence[int]]):
    """Load one image and crop it to its bounding box.

    Args:
        image_data: pair of (image path relative to ``{data_path}/images``,
            bounding box). The box is read as ``[x, y, width, height]``
            with x horizontal and y vertical, which is the COCO ``bbox``
            layout -- presumably what the caller passes; verify.

    Returns:
        Tuple of (cropped image, the bounding box exactly as received).
    """
    path, label = image_data
    # NOTE(review): the f-string only yields a valid path when ``path`` is a
    # plain Python string; if it arrives as a tensor -- confirm at call site.
    image = tf.read_file(f"{data_path}/images/{path}")
    image = tf.image.decode_image(image, channels=3)

    x1, y1 = label[0], label[1]
    x2, y2 = x1 + label[2], y1 + label[3]
    # Slice bounds must be integers; COCO stores bbox values as floats.
    x1, x2, y1, y2 = (tf.cast(value, tf.int32) for value in (x1, x2, y1, y2))

    # Decoded images are laid out [height, width, channels]: rows (y) are the
    # first axis and columns (x) the second, so y has to be sliced first.
    # NOTE(review): the "+ 1" keeps the original inclusive end -- confirm
    # that one extra row/column is intended.
    image_cut = image[y1:y2 + 1, x1:x2 + 1]

    return image_cut, label
# build image data set # build image data set
path_dataset = tf.data.Dataset.from_tensor_slices(file_names) path_dataset = tf.data.Dataset.from_tensor_slices(file_names)
label_dataset = tf.data.Dataset.from_tensor_slices(cat_ids) label_dataset = tf.data.Dataset.from_tensor_slices(bboxes)
dataset = tf.data.Dataset.zip((path_dataset, label_dataset)) dataset = tf.data.Dataset.zip((path_dataset, label_dataset))
dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=length_dataset, count=num_epochs)) dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=length_dataset, count=num_epochs))
dataset = dataset.batch(batch_size=batch_size) dataset = dataset.batch(batch_size=batch_size)