Save category to names dictionary as well

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
Jim Martens 2019-07-15 17:20:35 +02:00
parent 83f70b4fdb
commit 13a3cb9170
1 changed files with 14 additions and 7 deletions

View File

@ -138,15 +138,21 @@ def measure_mapping(args: argparse.Namespace) -> None:
output_path, coco_path, ground_truth_path = _measure_get_config_values(conf.get_property)
output_path, annotation_file_train = _measure_prepare_paths(args, output_path, coco_path)
instances, cats_to_classes = _measure_load_gt(ground_truth_path, annotation_file_train,
coco_utils.get_coco_category_maps)
instances, cats_to_classes, cats_to_names = _measure_load_gt(ground_truth_path, annotation_file_train,
coco_utils.get_coco_category_maps)
nr_digits = _get_nr_digits(instances)
_measure(instances, cats_to_classes, nr_digits, output_path)
_measure(instances, cats_to_classes, cats_to_names, nr_digits, output_path)
def _measure(instances: Sequence[Sequence[Sequence[dict]]], cats_to_classes: Dict[int, int],
def _measure(instances: Sequence[Sequence[Sequence[dict]]],
cats_to_classes: Dict[int, int],
cats_to_names: Dict[int, str],
nr_digits: int, output_path: str) -> None:
import pickle
with open(f"{output_path}/names.bin", "wb") as file:
pickle.dump(cats_to_names, file)
for i, trajectory in enumerate(instances):
counts = {cat_id: 0 for cat_id in cats_to_classes.keys()}
for labels in trajectory:
@ -658,14 +664,15 @@ def _ssd_test_load_gt(gt_path: str) -> Tuple[Sequence[Sequence[str]],
def _measure_load_gt(gt_path: str, annotation_file_train: str,
get_coco_cat_maps_func: callable) -> Tuple[Sequence[Sequence[Sequence[dict]]],
Dict[int, int]]:
Dict[int, int],
Dict[int, str]]:
import pickle
with open(f"{gt_path}/instances.bin", "rb") as file:
instances = pickle.load(file)
cats_to_classes, _, _, _ = get_coco_cat_maps_func(annotation_file_train)
cats_to_classes, _, cats_to_names, _ = get_coco_cat_maps_func(annotation_file_train)
return instances, cats_to_classes
return instances, cats_to_classes, cats_to_names
def _ssd_train_get_generators(args: argparse.Namespace,