diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 1ac1c6e..ab199f2 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -138,15 +138,21 @@ def measure_mapping(args: argparse.Namespace) -> None: output_path, coco_path, ground_truth_path = _measure_get_config_values(conf.get_property) output_path, annotation_file_train = _measure_prepare_paths(args, output_path, coco_path) - instances, cats_to_classes = _measure_load_gt(ground_truth_path, annotation_file_train, - coco_utils.get_coco_category_maps) + instances, cats_to_classes, cats_to_names = _measure_load_gt(ground_truth_path, annotation_file_train, + coco_utils.get_coco_category_maps) nr_digits = _get_nr_digits(instances) - _measure(instances, cats_to_classes, nr_digits, output_path) + _measure(instances, cats_to_classes, cats_to_names, nr_digits, output_path) -def _measure(instances: Sequence[Sequence[Sequence[dict]]], cats_to_classes: Dict[int, int], +def _measure(instances: Sequence[Sequence[Sequence[dict]]], + cats_to_classes: Dict[int, int], + cats_to_names: Dict[int, str], nr_digits: int, output_path: str) -> None: import pickle + + with open(f"{output_path}/names.bin", "wb") as file: + pickle.dump(cats_to_names, file) + for i, trajectory in enumerate(instances): counts = {cat_id: 0 for cat_id in cats_to_classes.keys()} for labels in trajectory: @@ -658,14 +664,15 @@ def _ssd_test_load_gt(gt_path: str) -> Tuple[Sequence[Sequence[str]], def _measure_load_gt(gt_path: str, annotation_file_train: str, get_coco_cat_maps_func: callable) -> Tuple[Sequence[Sequence[Sequence[dict]]], - Dict[int, int]]: + Dict[int, int], + Dict[int, str]]: import pickle with open(f"{gt_path}/instances.bin", "rb") as file: instances = pickle.load(file) - cats_to_classes, _, _, _ = get_coco_cat_maps_func(annotation_file_train) + cats_to_classes, _, cats_to_names, _ = get_coco_cat_maps_func(annotation_file_train) - return instances, cats_to_classes + return instances, cats_to_classes, cats_to_names def _ssd_train_get_generators(args: argparse.Namespace,