diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index b70ec51..79a9061 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -399,3 +399,25 @@ def visualise(args: argparse.Namespace) -> None: pyplot.close(figure) i += 1 + + +def measure_mapping(args: argparse.Namespace) -> None: + import pickle + + from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils + + with open(f"{args.ground_truth_path}/instances.bin", "rb") as file: + instances = pickle.load(file) + + output_path = f"{args.output_path}/measure/{args.tarball_id}" + annotation_file_train = f"{args.coco_path}/annotations/instances_train2014.json" + cats_to_classes, _, _, _ = coco_utils.get_coco_category_maps(annotation_file_train) + + for i, trajectory in enumerate(instances): + counts = {cat_id: 0 for cat_id in cats_to_classes.keys()} + for labels in trajectory: + for instance in labels: + counts[instance['coco_id']] += 1 + + with open(f"{output_path}/{i}.bin", "wb") as file: + pickle.dump(counts, file) diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 8d1ec06..8164f60 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -44,6 +44,7 @@ def main() -> None: evaluate_parser = sub_parsers.add_parser("evaluate", help="Evaluate a network") test_parser = sub_parsers.add_parser("test", help="Test a network") visualise_parser = sub_parsers.add_parser("visualise", help="Visualise the ground truth") + measure_parser = sub_parsers.add_parser("measure_mapping", help="Measure the number of instances per COCO category") # build sub parsers _build_prepare(prepare_parser) @@ -51,6 +52,7 @@ def main() -> None: _build_test(test_parser) _build_evaluate(evaluate_parser) _build_visualise(visualise_parser) + _build_measure(measure_parser) args = parser.parse_args() @@ -64,6 +66,8 @@ def main() -> None: cli.prepare(args) elif args.action == "visualise": cli.visualise(args) + elif args.action == "measure_mapping": + cli.measure_mapping(args) def _build_prepare(parser: argparse.ArgumentParser) -> None: @@ -166,5 +170,12 @@ def _build_visualise(parser: argparse.ArgumentParser) -> None: parser.add_argument("trajectory", type=int, help="trajectory to visualise") +def _build_measure(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--coco_path", type=str, help="the path to the COCO data set") + parser.add_argument("--ground_truth_path", type=str, help="path to the prepared ground truth directory") + parser.add_argument("--output_path", type=str, help="path to the output directory") + parser.add_argument("tarball_id", type=int, help="id of the used tarball") + + if __name__ == "__main__": main()