Added code to visualise metrics
For starters, only precision/recall is visualised Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -91,6 +91,12 @@ def visualise(args: argparse.Namespace) -> None:
|
|||||||
_visualise_gt(args, file_names, instances, cats_to_classes, cats_to_names, output_path)
|
_visualise_gt(args, file_names, instances, cats_to_classes, cats_to_names, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def visualise_metrics(args: argparse.Namespace) -> None:
|
||||||
|
output_path, evaluation_path = _visualise_metrics_get_config_values(conf.get_property)
|
||||||
|
output_path, metrics_file = _visualise_metrics_prepare_paths(args, output_path, evaluation_path)
|
||||||
|
_visualise_metrics(_visualise_precision_recall, output_path, metrics_file)
|
||||||
|
|
||||||
|
|
||||||
def measure_mapping(args: argparse.Namespace) -> None:
|
def measure_mapping(args: argparse.Namespace) -> None:
|
||||||
from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils
|
from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils
|
||||||
|
|
||||||
@ -415,6 +421,27 @@ def _visualise_gt(args: argparse.Namespace,
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _visualise_metrics(visualise_precision_recall: callable,
|
||||||
|
output_path: str,
|
||||||
|
metrics_file: str) -> None:
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(metrics_file, "rb") as file:
|
||||||
|
metrics = pickle.load(file)
|
||||||
|
|
||||||
|
precision_micro = metrics["cumulative_precision_micro"]
|
||||||
|
recall_micro = metrics["cumulative_recall_micro"]
|
||||||
|
visualise_precision_recall(precision_micro, recall_micro,
|
||||||
|
output_path, "micro")
|
||||||
|
|
||||||
|
precision_macro = metrics["cumulative_precision_macro"]
|
||||||
|
recall_macro = metrics["cumulative_recall_macro"]
|
||||||
|
visualise_precision_recall(precision_macro, recall_macro,
|
||||||
|
output_path, "macro")
|
||||||
|
|
||||||
|
# TODO add further metrics
|
||||||
|
|
||||||
|
|
||||||
def _init_eager_mode() -> None:
|
def _init_eager_mode() -> None:
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
@ -620,6 +647,14 @@ def _visualise_get_config_values(config_get: Callable[[str], Union[str, int, flo
|
|||||||
return output_path, coco_path, ground_truth_path
|
return output_path, coco_path, ground_truth_path
|
||||||
|
|
||||||
|
|
||||||
|
def _visualise_metrics_get_config_values(config_get: Callable[[str], Union[str, int, float, bool]]
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
output_path = config_get("Paths.output")
|
||||||
|
evaluation_path = config_get("Paths.evaluation")
|
||||||
|
|
||||||
|
return output_path, evaluation_path
|
||||||
|
|
||||||
|
|
||||||
def _ssd_is_dropout(args: argparse.Namespace) -> bool:
|
def _ssd_is_dropout(args: argparse.Namespace) -> bool:
|
||||||
return False if args.network == "ssd" else True
|
return False if args.network == "ssd" else True
|
||||||
|
|
||||||
@ -703,6 +738,19 @@ def _visualise_prepare_paths(args: argparse.Namespace,
|
|||||||
return output_path, annotation_file_train
|
return output_path, annotation_file_train
|
||||||
|
|
||||||
|
|
||||||
|
def _visualise_metrics_prepare_paths(args: argparse.Namespace,
|
||||||
|
output_path: str,
|
||||||
|
evaluation_path: str) -> Tuple[str, str]:
|
||||||
|
import os
|
||||||
|
|
||||||
|
metrics_file = f"{evaluation_path}/{args.network}/results-{args.iteration}.bin"
|
||||||
|
output_path = f"{output_path}/{args.network}/visualise/{args.iteration}"
|
||||||
|
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
return output_path, metrics_file
|
||||||
|
|
||||||
|
|
||||||
def _ssd_train_load_gt(train_gt_path: str, val_gt_path: str
|
def _ssd_train_load_gt(train_gt_path: str, val_gt_path: str
|
||||||
) -> Tuple[Sequence[Sequence[str]],
|
) -> Tuple[Sequence[Sequence[str]],
|
||||||
Sequence[Sequence[Sequence[dict]]],
|
Sequence[Sequence[Sequence[dict]]],
|
||||||
@ -950,6 +998,20 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _visualise_precision_recall(precision: np.ndarray, recall: np.ndarray,
|
||||||
|
output_path: str, file_suffix: str) -> None:
|
||||||
|
from matplotlib import pyplot
|
||||||
|
|
||||||
|
figure = pyplot.figure()
|
||||||
|
|
||||||
|
pyplot.ylabel("precision")
|
||||||
|
pyplot.xlabel("recall")
|
||||||
|
pyplot.plot(recall, precision)
|
||||||
|
|
||||||
|
pyplot.savefig(f"{output_path}/precision-recall-{file_suffix}.png")
|
||||||
|
pyplot.close(figure)
|
||||||
|
|
||||||
|
|
||||||
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|||||||
@ -43,7 +43,8 @@ def main() -> None:
|
|||||||
_build_test(sub_parsers[3])
|
_build_test(sub_parsers[3])
|
||||||
_build_evaluate(sub_parsers[4])
|
_build_evaluate(sub_parsers[4])
|
||||||
_build_visualise(sub_parsers[5])
|
_build_visualise(sub_parsers[5])
|
||||||
_build_measure(sub_parsers[6])
|
_build_visualise_metrics(sub_parsers[6])
|
||||||
|
_build_measure(sub_parsers[7])
|
||||||
|
|
||||||
args = _get_user_input(parser)
|
args = _get_user_input(parser)
|
||||||
_execute_action(args,
|
_execute_action(args,
|
||||||
@ -53,6 +54,7 @@ def main() -> None:
|
|||||||
cli.test,
|
cli.test,
|
||||||
cli.evaluate,
|
cli.evaluate,
|
||||||
cli.visualise,
|
cli.visualise,
|
||||||
|
cli.visualise_metrics,
|
||||||
cli.measure_mapping)
|
cli.measure_mapping)
|
||||||
|
|
||||||
|
|
||||||
@ -71,6 +73,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
|
|||||||
test_parser = sub_parsers.add_parser("test", help="Test a network")
|
test_parser = sub_parsers.add_parser("test", help="Test a network")
|
||||||
evaluate_parser = sub_parsers.add_parser("evaluate", help="Evaluate a network")
|
evaluate_parser = sub_parsers.add_parser("evaluate", help="Evaluate a network")
|
||||||
visualise_parser = sub_parsers.add_parser("visualise", help="Visualise the ground truth")
|
visualise_parser = sub_parsers.add_parser("visualise", help="Visualise the ground truth")
|
||||||
|
visualise_metrics_parser = sub_parsers.add_parser("visualise_metrics", help="Visualise the evaluation results")
|
||||||
measure_parser = sub_parsers.add_parser("measure_mapping", help="Measure the number of instances per COCO category")
|
measure_parser = sub_parsers.add_parser("measure_mapping", help="Measure the number of instances per COCO category")
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@ -80,6 +83,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
|
|||||||
test_parser,
|
test_parser,
|
||||||
evaluate_parser,
|
evaluate_parser,
|
||||||
visualise_parser,
|
visualise_parser,
|
||||||
|
visualise_metrics_parser,
|
||||||
measure_parser
|
measure_parser
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -99,6 +103,7 @@ def _execute_action(args: argparse.Namespace,
|
|||||||
on_test: callable,
|
on_test: callable,
|
||||||
on_evaluate: callable,
|
on_evaluate: callable,
|
||||||
on_visualise: callable,
|
on_visualise: callable,
|
||||||
|
on_visualise_metrics: callable,
|
||||||
on_measure: callable) -> None:
|
on_measure: callable) -> None:
|
||||||
if args.component == "config":
|
if args.component == "config":
|
||||||
on_config(args)
|
on_config(args)
|
||||||
@ -112,6 +117,8 @@ def _execute_action(args: argparse.Namespace,
|
|||||||
on_evaluate(args)
|
on_evaluate(args)
|
||||||
elif args.component == "visualise":
|
elif args.component == "visualise":
|
||||||
on_visualise(args)
|
on_visualise(args)
|
||||||
|
elif args.component == "visualise_metrics":
|
||||||
|
on_visualise_metrics(args)
|
||||||
elif args.component == "measure_mapping":
|
elif args.component == "measure_mapping":
|
||||||
on_measure(args)
|
on_measure(args)
|
||||||
|
|
||||||
@ -214,6 +221,10 @@ def _build_visualise(parser: argparse.ArgumentParser) -> None:
|
|||||||
parser.add_argument("trajectory", type=int, help="trajectory to visualise")
|
parser.add_argument("trajectory", type=int, help="trajectory to visualise")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_visualise_metrics(parser: argparse.ArgumentParser) -> None:
|
||||||
|
parser.add_argument("iteration", type=int, help="the validation iteration to use")
|
||||||
|
|
||||||
|
|
||||||
def _build_measure(parser: argparse.ArgumentParser) -> None:
|
def _build_measure(parser: argparse.ArgumentParser) -> None:
|
||||||
parser.add_argument("tarball_id", type=str, help="id of the used tarball. number for training tarball or 'test'")
|
parser.add_argument("tarball_id", type=str, help="id of the used tarball. number for training tarball or 'test'")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user