diff --git a/src/twomartens/masterthesis/data.py b/src/twomartens/masterthesis/data.py index 98f14c2..2e6e921 100644 --- a/src/twomartens/masterthesis/data.py +++ b/src/twomartens/masterthesis/data.py @@ -27,28 +27,42 @@ import tensorflow as tf from pycocotools import coco -def load_coco(data_path: str, data_type: str, category: int, +def load_coco(data_path: str, category: int, num_epochs: int, batch_size: int = 32) -> tf.data.Dataset: """ - Loads the COCO data and returns a data set. + Loads the COCO trainval35k data and returns a data set. Args: data_path: path to the COCO data set - data_type: type of the COCO data (e.g. 'val2014') category: id of the inlying class num_epochs: number of epochs batch_size: batch size (default: 32) Returns: Tensorflow data set """ - annotation_file = f"{data_path}/annotations/instances_{data_type}.json" - coco_interface = coco.COCO(annotation_file) - img_ids = coco_interface.getImgIds(catIds=[category]) # return all image IDs belonging to given category - images = coco_interface.loadImgs(img_ids) # load all images - annotation_ids = coco_interface.getAnnIds(img_ids) - annotations = coco_interface.loadAnns(annotation_ids) # load all image annotations - file_names = [image['file_name'] for image in images] + annotation_file_train = f"{data_path}/annotations/instances_train2014.json" + annotation_file_val = f"{data_path}/annotations/instances_valminusminival2014.json" + # load training images + coco_train = coco.COCO(annotation_file_train) + img_ids = coco_train.getImgIds(catIds=[category]) # return all image IDs belonging to given category + images = coco_train.loadImgs(img_ids) # load all images + annotation_ids = coco_train.getAnnIds(img_ids) + annotations = coco_train.loadAnns(annotation_ids) # load all image annotations + file_names = [f"train2014/{image['file_name']}" for image in images] cat_ids = [annotation['category_id'] for annotation in annotations] + + # load validation images + coco_val = coco.COCO(annotation_file_val) + img_ids = coco_val.getImgIds(catIds=[category]) # return all image IDs belonging to given category + images = coco_val.loadImgs(img_ids) # load all images + annotation_ids = coco_val.getAnnIds(img_ids) + annotations = coco_val.loadAnns(annotation_ids) # load all image annotations + file_names_val = [f"val2014/{image['file_name']}" for image in images] + cat_ids_val = [annotation['category_id'] for annotation in annotations] + + file_names.extend(file_names_val) + cat_ids.extend(cat_ids_val) + length_dataset = len(file_names) def _load_image(image_data: Tuple[str, int]):