Added argument to specify the inlying class

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-04-04 17:35:11 +02:00
parent d499c55ab4
commit a8ef3d000f

View File

@ -27,13 +27,15 @@ import tensorflow as tf
from pycocotools import coco from pycocotools import coco
def load_coco(data_path: str, data_type: str, num_epochs: int, batch_size: int = 32) -> tf.data.Dataset: def load_coco(data_path: str, data_type: str, category: int,
num_epochs: int, batch_size: int = 32) -> tf.data.Dataset:
""" """
Loads the COCO data and returns a data set. Loads the COCO data and returns a data set.
Args: Args:
data_path: path to the COCO data set data_path: path to the COCO data set
data_type: type of the COCO data (e.g. 'val2014') data_type: type of the COCO data (e.g. 'val2014')
category: id of the inlying class
num_epochs: number of epochs num_epochs: number of epochs
batch_size: batch size (default: 32) batch_size: batch size (default: 32)
Returns: Returns:
@ -41,7 +43,7 @@ def load_coco(data_path: str, data_type: str, num_epochs: int, batch_size: int =
""" """
annotation_file = f"{data_path}/annotations/instances_{data_type}.json" annotation_file = f"{data_path}/annotations/instances_{data_type}.json"
coco_interface = coco.COCO(annotation_file) coco_interface = coco.COCO(annotation_file)
img_ids = coco_interface.getImgIds() # return all image IDs img_ids = coco_interface.getImgIds(catIds=[category]) # return all image IDs belonging to given category
images = coco_interface.loadImgs(img_ids) # load all images images = coco_interface.loadImgs(img_ids) # load all images
annotation_ids = coco_interface.getAnnIds(img_ids) annotation_ids = coco_interface.getAnnIds(img_ids)
annotations = coco_interface.loadAnns(annotation_ids) # load all image annotations annotations = coco_interface.loadAnns(annotation_ids) # load all image annotations