Added data loader for COCO data set
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
66
src/twomartens/masterthesis/data.py
Normal file
66
src/twomartens/masterthesis/data.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright 2019 Jim Martens
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Functionality to load COCO data into Tensorflow data sets.
|
||||||
|
|
||||||
|
Functions:
|
||||||
|
load_coco(...): loads the COCO data into a Tensorflow data set
|
||||||
|
"""
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from pycocotools import coco
|
||||||
|
|
||||||
|
|
||||||
|
def load_coco(data_path: str, data_type: str, num_epochs: int, batch_size: int = 32) -> tf.data.Dataset:
    """
    Loads the COCO data and returns a data set.

    Each element of the returned data set is an ``(image, label)`` pair where
    ``label`` is the category id of one annotation on that image; an image
    with several annotations therefore appears once per annotation.

    Args:
        data_path: path to the COCO data set
        data_type: type of the COCO data (e.g. 'val2014')
        num_epochs: number of epochs
        batch_size: batch size

    Returns:
        Tensorflow data set
    """
    annotation_file = f"{data_path}/annotations/instances_{data_type}.json"
    coco_interface = coco.COCO(annotation_file)
    img_ids = coco_interface.getImgIds()  # return all image IDs
    images = coco_interface.loadImgs(img_ids)  # load all images
    annotation_ids = coco_interface.getAnnIds(img_ids)
    annotations = coco_interface.loadAnns(annotation_ids)  # load all image annotations

    # Pair every annotation with the file name of the image it belongs to.
    # Zipping the per-image file list directly against the per-annotation
    # category list would mis-align them (and silently truncate), because an
    # image usually carries more than one annotation.
    file_name_by_id = {image['id']: image['file_name'] for image in images}
    file_names = [file_name_by_id[annotation['image_id']] for annotation in annotations]
    cat_ids = [annotation['category_id'] for annotation in annotations]
    length_dataset = len(file_names)

    def _load_image(path: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # tf.data unpacks tuple elements into separate positional arguments,
        # so the mapper receives (path, label) rather than a single tuple.
        # The file name must be joined at graph level: an f-string would
        # embed the symbolic tensor's repr into the path string.
        full_path = tf.strings.join([f"{data_path}/images/", path])
        image = tf.read_file(full_path)
        image = tf.image.decode_image(image, channels=3)

        return image, label

    # build image data set
    path_dataset = tf.data.Dataset.from_tensor_slices(file_names)
    label_dataset = tf.data.Dataset.from_tensor_slices(cat_ids)
    dataset = tf.data.Dataset.zip((path_dataset, label_dataset))
    dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=length_dataset, count=num_epochs))
    # Decode individual images BEFORE batching: batching first would hand
    # _load_image a whole batch of paths, which tf.read_file cannot handle.
    dataset = dataset.map(_load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # NOTE(review): plain batch() assumes all decoded images share one shape;
    # COCO images vary in size, so a resize step or padded_batch may be
    # required downstream — confirm against the consumer of this data set.
    dataset = dataset.batch(batch_size=batch_size)

    return dataset
|
||||||
Reference in New Issue
Block a user