diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 5fadcf3..98f14c2 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -27,13 +27,15 @@ import tensorflow as tf from pycocotools import coco -def load_coco(data_path: str, data_type: str, num_epochs: int, batch_size: int = 32) -> tf.data.Dataset: +def load_coco(data_path: str, data_type: str, category: int, + num_epochs: int, batch_size: int = 32) -> tf.data.Dataset: """ Loads the COCO data and returns a data set. Args: data_path: path to the COCO data set data_type: type of the COCO data (e.g. 'val2014') + category: id of the inlying class num_epochs: number of epochs batch_size: batch size (default: 32) Returns: @@ -41,7 +43,7 @@ def load_coco(data_path: str, data_type: str, num_epochs: int, batch_size: int = """ annotation_file = f"{data_path}/annotations/instances_{data_type}.json" coco_interface = coco.COCO(annotation_file) - img_ids = coco_interface.getImgIds() # return all image IDs + img_ids = coco_interface.getImgIds(catIds=[category]) # return all image IDs belonging to given category images = coco_interface.loadImgs(img_ids) # load all images annotation_ids = coco_interface.getAnnIds(img_ids) annotations = coco_interface.loadAnns(annotation_ids) # load all image annotations