Modified data module to cut everything but the bounding box
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -21,6 +21,7 @@ Functions:
|
||||
load_coco(...): loads the COCO data into a Tensorflow data set
|
||||
load_scenenet(...): loads the SceneNet RGB-D data into a Tensorflow data set
|
||||
"""
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
@ -49,7 +50,7 @@ def load_coco(data_path: str, category: int,
|
||||
annotation_ids = coco_train.getAnnIds(img_ids)
|
||||
annotations = coco_train.loadAnns(annotation_ids) # load all image annotations
|
||||
file_names = [f"train2014/{image['file_name']}" for image in images]
|
||||
cat_ids = [annotation['category_id'] for annotation in annotations]
|
||||
bboxes = [annotation['bbox'] for annotation in annotations]
|
||||
|
||||
# load validation images
|
||||
coco_val = coco.COCO(annotation_file_val)
|
||||
@ -58,23 +59,25 @@ def load_coco(data_path: str, category: int,
|
||||
annotation_ids = coco_val.getAnnIds(img_ids)
|
||||
annotations = coco_val.loadAnns(annotation_ids) # load all image annotations
|
||||
file_names_val = [f"val2014/{image['file_name']}" for image in images]
|
||||
cat_ids_val = [annotation['category_id'] for annotation in annotations]
|
||||
bboxes_val = [annotation['bbox'] for annotation in annotations]
|
||||
|
||||
file_names.extend(file_names_val)
|
||||
cat_ids.extend(cat_ids_val)
|
||||
bboxes.extend(bboxes_val)
|
||||
|
||||
length_dataset = len(file_names)
|
||||
|
||||
def _load_image(image_data: Tuple[str, int]):
|
||||
def _load_image(image_data: Tuple[str, Sequence[int]]):
|
||||
path, label = image_data
|
||||
image = tf.read_file(f"{data_path}/images/{path}")
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
x1, x2, y1, y2 = label[0], label[0] + label[2], label[1], label[1] + label[3]
|
||||
image_cut = image[x1:x2 + 1, y1:y2 + 1]
|
||||
|
||||
return image, label
|
||||
return image_cut, label
|
||||
|
||||
# build image data set
|
||||
path_dataset = tf.data.Dataset.from_tensor_slices(file_names)
|
||||
label_dataset = tf.data.Dataset.from_tensor_slices(cat_ids)
|
||||
label_dataset = tf.data.Dataset.from_tensor_slices(bboxes)
|
||||
dataset = tf.data.Dataset.zip((path_dataset, label_dataset))
|
||||
dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=length_dataset, count=num_epochs))
|
||||
dataset = dataset.batch(batch_size=batch_size)
|
||||
|
||||
Reference in New Issue
Block a user