Modified data loader to supply the full trainval35k data

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-04 19:04:19 +02:00
parent a8ef3d000f
commit 57da44c6d8

View File

@ -27,28 +27,42 @@ import tensorflow as tf
from pycocotools import coco
def load_coco(data_path: str, data_type: str, category: int,
def load_coco(data_path: str, category: int,
num_epochs: int, batch_size: int = 32) -> tf.data.Dataset:
"""
Loads the COCO data and returns a data set.
Loads the COCO trainval35k data and returns a data set.
Args:
data_path: path to the COCO data set
data_type: type of the COCO data (e.g. 'val2014')
category: id of the inlying class
num_epochs: number of epochs
batch_size: batch size (default: 32)
Returns:
Tensorflow data set
"""
annotation_file = f"{data_path}/annotations/instances_{data_type}.json"
coco_interface = coco.COCO(annotation_file)
img_ids = coco_interface.getImgIds(catIds=[category]) # return all image IDs belonging to given category
images = coco_interface.loadImgs(img_ids) # load all images
annotation_ids = coco_interface.getAnnIds(img_ids)
annotations = coco_interface.loadAnns(annotation_ids) # load all image annotations
file_names = [image['file_name'] for image in images]
annotation_file_train = f"{data_path}/annotations/instances_train2014.json"
annotation_file_val = f"{data_path}/annotations/instances_valminusminival2014.json"
# load training images
coco_train = coco.COCO(annotation_file_train)
img_ids = coco_train.getImgIds(catIds=[category]) # return all image IDs belonging to given category
images = coco_train.loadImgs(img_ids) # load all images
annotation_ids = coco_train.getAnnIds(img_ids)
annotations = coco_train.loadAnns(annotation_ids) # load all image annotations
file_names = [f"train2014/{image['file_name']}" for image in images]
cat_ids = [annotation['category_id'] for annotation in annotations]
# load validation images
coco_val = coco.COCO(annotation_file_val)
img_ids = coco_val.getImgIds(catIds=[category]) # return all image IDs belonging to given category
images = coco_val.loadImgs(img_ids) # load all images
annotation_ids = coco_val.getAnnIds(img_ids)
annotations = coco_val.loadAnns(annotation_ids) # load all image annotations
file_names_val = [f"val2014/{image['file_name']}" for image in images]
cat_ids_val = [annotation['category_id'] for annotation in annotations]
file_names.extend(file_names_val)
cat_ids.extend(cat_ids_val)
length_dataset = len(file_names)
def _load_image(image_data: Tuple[str, int]):