Added code to visualise metrics

For starters, only precision/recall is visualised

Signed-off-by: Jim Martens <github@2martens.de>
2019-07-23 13:14:22 +02:00
parent d6047665ef
commit 53472a8342
2 changed files with 74 additions and 1 deletion

File 1 of 2

@@ -91,6 +91,12 @@ def visualise(args: argparse.Namespace) -> None:
     _visualise_gt(args, file_names, instances, cats_to_classes, cats_to_names, output_path)
 
 
+def visualise_metrics(args: argparse.Namespace) -> None:
+    output_path, evaluation_path = _visualise_metrics_get_config_values(conf.get_property)
+    output_path, metrics_file = _visualise_metrics_prepare_paths(args, output_path, evaluation_path)
+    _visualise_metrics(_visualise_precision_recall, output_path, metrics_file)
+
+
 def measure_mapping(args: argparse.Namespace) -> None:
     from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils
@@ -415,6 +421,27 @@ def _visualise_gt(args: argparse.Namespace,
         i += 1
 
 
+def _visualise_metrics(visualise_precision_recall: callable,
+                       output_path: str,
+                       metrics_file: str) -> None:
+    import pickle
+
+    with open(metrics_file, "rb") as file:
+        metrics = pickle.load(file)
+
+    precision_micro = metrics["cumulative_precision_micro"]
+    recall_micro = metrics["cumulative_recall_micro"]
+    visualise_precision_recall(precision_micro, recall_micro,
+                               output_path, "micro")
+
+    precision_macro = metrics["cumulative_precision_macro"]
+    recall_macro = metrics["cumulative_recall_macro"]
+    visualise_precision_recall(precision_macro, recall_macro,
+                               output_path, "macro")
+    # TODO add further metrics
+
+
 def _init_eager_mode() -> None:
     tf.enable_eager_execution()
@@ -620,6 +647,14 @@ def _visualise_get_config_values(config_get: Callable[[str], Union[str, int, flo
     return output_path, coco_path, ground_truth_path
 
 
+def _visualise_metrics_get_config_values(config_get: Callable[[str], Union[str, int, float, bool]]
+                                         ) -> Tuple[str, str]:
+    output_path = config_get("Paths.output")
+    evaluation_path = config_get("Paths.evaluation")
+
+    return output_path, evaluation_path
+
+
 def _ssd_is_dropout(args: argparse.Namespace) -> bool:
     return False if args.network == "ssd" else True
@@ -703,6 +738,19 @@ def _visualise_prepare_paths(args: argparse.Namespace,
     return output_path, annotation_file_train
 
 
+def _visualise_metrics_prepare_paths(args: argparse.Namespace,
+                                     output_path: str,
+                                     evaluation_path: str) -> Tuple[str, str]:
+    import os
+
+    metrics_file = f"{evaluation_path}/{args.network}/results-{args.iteration}.bin"
+    output_path = f"{output_path}/{args.network}/visualise/{args.iteration}"
+    os.makedirs(output_path, exist_ok=True)
+
+    return output_path, metrics_file
+
+
 def _ssd_train_load_gt(train_gt_path: str, val_gt_path: str
                        ) -> Tuple[Sequence[Sequence[str]],
                                   Sequence[Sequence[Sequence[dict]]],
@@ -950,6 +998,20 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
     return results
 
 
+def _visualise_precision_recall(precision: np.ndarray, recall: np.ndarray,
+                                output_path: str, file_suffix: str) -> None:
+    from matplotlib import pyplot
+
+    figure = pyplot.figure()
+    pyplot.ylabel("precision")
+    pyplot.xlabel("recall")
+    pyplot.plot(recall, precision)
+    pyplot.savefig(f"{output_path}/precision-recall-{file_suffix}.png")
+    pyplot.close(figure)
+
+
 def _auto_encoder_train(args: argparse.Namespace) -> None:
     import os
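
Taken together, the new helpers define a small on-disk contract: the evaluate step is expected to have written a pickled dict of cumulative precision/recall arrays to {evaluation_path}/{network}/results-{iteration}.bin. A minimal, self-contained sketch of that contract follows; only the key names and file layout come from the diff above, while the paths and curve data are purely illustrative:

import os
import pickle

import numpy as np

# Illustrative values only; the real ones come from Paths.output and
# Paths.evaluation in the config plus args.network and args.iteration.
evaluation_path = "/tmp/evaluation"
network, iteration = "ssd", 1
os.makedirs(f"{evaluation_path}/{network}", exist_ok=True)

# Made-up cumulative precision/recall curves standing in for real
# evaluation output.
recall = np.linspace(0, 1, 50)
metrics = {
    "cumulative_precision_micro": 1 - recall ** 2,
    "cumulative_recall_micro": recall,
    "cumulative_precision_macro": 1 - recall,
    "cumulative_recall_macro": recall,
}

# The evaluate step is assumed to write the metrics like this ...
with open(f"{evaluation_path}/{network}/results-{iteration}.bin", "wb") as file:
    pickle.dump(metrics, file)

# ... which is exactly what _visualise_metrics unpickles and hands,
# pair by pair, to _visualise_precision_recall.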

File 2 of 2

@@ -43,7 +43,8 @@ def main() -> None:
     _build_test(sub_parsers[3])
     _build_evaluate(sub_parsers[4])
     _build_visualise(sub_parsers[5])
-    _build_measure(sub_parsers[6])
+    _build_visualise_metrics(sub_parsers[6])
+    _build_measure(sub_parsers[7])
 
     args = _get_user_input(parser)
     _execute_action(args,
@@ -53,6 +54,7 @@ def main() -> None:
                     cli.test,
                     cli.evaluate,
                     cli.visualise,
+                    cli.visualise_metrics,
                     cli.measure_mapping)
@@ -71,6 +73,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
     test_parser = sub_parsers.add_parser("test", help="Test a network")
     evaluate_parser = sub_parsers.add_parser("evaluate", help="Evaluate a network")
     visualise_parser = sub_parsers.add_parser("visualise", help="Visualise the ground truth")
+    visualise_metrics_parser = sub_parsers.add_parser("visualise_metrics", help="Visualise the evaluation results")
     measure_parser = sub_parsers.add_parser("measure_mapping", help="Measure the number of instances per COCO category")
 
     return [
@@ -80,6 +83,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
         test_parser,
         evaluate_parser,
         visualise_parser,
+        visualise_metrics_parser,
         measure_parser
     ]
@@ -99,6 +103,7 @@ def _execute_action(args: argparse.Namespace,
                     on_test: callable,
                     on_evaluate: callable,
                     on_visualise: callable,
+                    on_visualise_metrics: callable,
                     on_measure: callable) -> None:
     if args.component == "config":
         on_config(args)
@@ -112,6 +117,8 @@ def _execute_action(args: argparse.Namespace,
         on_evaluate(args)
     elif args.component == "visualise":
         on_visualise(args)
+    elif args.component == "visualise_metrics":
+        on_visualise_metrics(args)
     elif args.component == "measure_mapping":
         on_measure(args)
@@ -214,6 +221,10 @@ def _build_visualise(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("trajectory", type=int, help="trajectory to visualise")
 
 
+def _build_visualise_metrics(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument("iteration", type=int, help="the validation iteration to use")
+
+
 def _build_measure(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("tarball_id", type=str, help="id of the used tarball. number for training tarball or 'test'")