Added code to visualise metrics
For starters, only precision/recall is visualised Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -91,6 +91,12 @@ def visualise(args: argparse.Namespace) -> None:
|
||||
_visualise_gt(args, file_names, instances, cats_to_classes, cats_to_names, output_path)
|
||||
|
||||
|
||||
def visualise_metrics(args: argparse.Namespace) -> None:
|
||||
output_path, evaluation_path = _visualise_metrics_get_config_values(conf.get_property)
|
||||
output_path, metrics_file = _visualise_metrics_prepare_paths(args, output_path, evaluation_path)
|
||||
_visualise_metrics(_visualise_precision_recall, output_path, metrics_file)
|
||||
|
||||
|
||||
def measure_mapping(args: argparse.Namespace) -> None:
|
||||
from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils
|
||||
|
||||
@ -415,6 +421,27 @@ def _visualise_gt(args: argparse.Namespace,
|
||||
i += 1
|
||||
|
||||
|
||||
def _visualise_metrics(visualise_precision_recall: callable,
|
||||
output_path: str,
|
||||
metrics_file: str) -> None:
|
||||
import pickle
|
||||
|
||||
with open(metrics_file, "rb") as file:
|
||||
metrics = pickle.load(file)
|
||||
|
||||
precision_micro = metrics["cumulative_precision_micro"]
|
||||
recall_micro = metrics["cumulative_recall_micro"]
|
||||
visualise_precision_recall(precision_micro, recall_micro,
|
||||
output_path, "micro")
|
||||
|
||||
precision_macro = metrics["cumulative_precision_macro"]
|
||||
recall_macro = metrics["cumulative_recall_macro"]
|
||||
visualise_precision_recall(precision_macro, recall_macro,
|
||||
output_path, "macro")
|
||||
|
||||
# TODO add further metrics
|
||||
|
||||
|
||||
def _init_eager_mode() -> None:
|
||||
tf.enable_eager_execution()
|
||||
|
||||
@ -620,6 +647,14 @@ def _visualise_get_config_values(config_get: Callable[[str], Union[str, int, flo
|
||||
return output_path, coco_path, ground_truth_path
|
||||
|
||||
|
||||
def _visualise_metrics_get_config_values(config_get: Callable[[str], Union[str, int, float, bool]]
|
||||
) -> Tuple[str, str]:
|
||||
output_path = config_get("Paths.output")
|
||||
evaluation_path = config_get("Paths.evaluation")
|
||||
|
||||
return output_path, evaluation_path
|
||||
|
||||
|
||||
def _ssd_is_dropout(args: argparse.Namespace) -> bool:
|
||||
return False if args.network == "ssd" else True
|
||||
|
||||
@ -703,6 +738,19 @@ def _visualise_prepare_paths(args: argparse.Namespace,
|
||||
return output_path, annotation_file_train
|
||||
|
||||
|
||||
def _visualise_metrics_prepare_paths(args: argparse.Namespace,
|
||||
output_path: str,
|
||||
evaluation_path: str) -> Tuple[str, str]:
|
||||
import os
|
||||
|
||||
metrics_file = f"{evaluation_path}/{args.network}/results-{args.iteration}.bin"
|
||||
output_path = f"{output_path}/{args.network}/visualise/{args.iteration}"
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
return output_path, metrics_file
|
||||
|
||||
|
||||
def _ssd_train_load_gt(train_gt_path: str, val_gt_path: str
|
||||
) -> Tuple[Sequence[Sequence[str]],
|
||||
Sequence[Sequence[Sequence[dict]]],
|
||||
@ -950,6 +998,20 @@ def _ssd_evaluate_get_results(true_positives: Sequence[np.ndarray],
|
||||
return results
|
||||
|
||||
|
||||
def _visualise_precision_recall(precision: np.ndarray, recall: np.ndarray,
|
||||
output_path: str, file_suffix: str) -> None:
|
||||
from matplotlib import pyplot
|
||||
|
||||
figure = pyplot.figure()
|
||||
|
||||
pyplot.ylabel("precision")
|
||||
pyplot.xlabel("recall")
|
||||
pyplot.plot(recall, precision)
|
||||
|
||||
pyplot.savefig(f"{output_path}/precision-recall-{file_suffix}.png")
|
||||
pyplot.close(figure)
|
||||
|
||||
|
||||
def _auto_encoder_train(args: argparse.Namespace) -> None:
|
||||
import os
|
||||
|
||||
|
||||
@ -43,7 +43,8 @@ def main() -> None:
|
||||
_build_test(sub_parsers[3])
|
||||
_build_evaluate(sub_parsers[4])
|
||||
_build_visualise(sub_parsers[5])
|
||||
_build_measure(sub_parsers[6])
|
||||
_build_visualise_metrics(sub_parsers[6])
|
||||
_build_measure(sub_parsers[7])
|
||||
|
||||
args = _get_user_input(parser)
|
||||
_execute_action(args,
|
||||
@ -53,6 +54,7 @@ def main() -> None:
|
||||
cli.test,
|
||||
cli.evaluate,
|
||||
cli.visualise,
|
||||
cli.visualise_metrics,
|
||||
cli.measure_mapping)
|
||||
|
||||
|
||||
@ -71,6 +73,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
|
||||
test_parser = sub_parsers.add_parser("test", help="Test a network")
|
||||
evaluate_parser = sub_parsers.add_parser("evaluate", help="Evaluate a network")
|
||||
visualise_parser = sub_parsers.add_parser("visualise", help="Visualise the ground truth")
|
||||
visualise_metrics_parser = sub_parsers.add_parser("visualise_metrics", help="Visualise the evaluation results")
|
||||
measure_parser = sub_parsers.add_parser("measure_mapping", help="Measure the number of instances per COCO category")
|
||||
|
||||
return [
|
||||
@ -80,6 +83,7 @@ def _build_sub_parsers(parser: argparse.ArgumentParser) -> List[argparse.Argumen
|
||||
test_parser,
|
||||
evaluate_parser,
|
||||
visualise_parser,
|
||||
visualise_metrics_parser,
|
||||
measure_parser
|
||||
]
|
||||
|
||||
@ -99,6 +103,7 @@ def _execute_action(args: argparse.Namespace,
|
||||
on_test: callable,
|
||||
on_evaluate: callable,
|
||||
on_visualise: callable,
|
||||
on_visualise_metrics: callable,
|
||||
on_measure: callable) -> None:
|
||||
if args.component == "config":
|
||||
on_config(args)
|
||||
@ -112,6 +117,8 @@ def _execute_action(args: argparse.Namespace,
|
||||
on_evaluate(args)
|
||||
elif args.component == "visualise":
|
||||
on_visualise(args)
|
||||
elif args.component == "visualise_metrics":
|
||||
on_visualise_metrics(args)
|
||||
elif args.component == "measure_mapping":
|
||||
on_measure(args)
|
||||
|
||||
@ -214,6 +221,10 @@ def _build_visualise(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("trajectory", type=int, help="trajectory to visualise")
|
||||
|
||||
|
||||
def _build_visualise_metrics(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("iteration", type=int, help="the validation iteration to use")
|
||||
|
||||
|
||||
def _build_measure(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("tarball_id", type=str, help="id of the used tarball. number for training tarball or 'test'")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user