Finished vanilla SSD evaluation code
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -174,7 +174,11 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
batch_size = 16
|
batch_size = 16
|
||||||
use_dropout = False if args.network == "ssd" else True
|
use_dropout = False if args.network == "ssd" else True
|
||||||
output_path = f"{args.output_path}/val/{args.network}/{args.iteration}"
|
output_path = f"{args.output_path}/val/{args.network}/{args.iteration}"
|
||||||
|
evaluation_path = f"{args.evaluation_path}/{args.network}"
|
||||||
|
result_file = f"{evaluation_path}/results-{args.iteration}.bin"
|
||||||
label_file = f"{output_path}/labels.bin"
|
label_file = f"{output_path}/labels.bin"
|
||||||
|
predictions_file = f"{output_path}/predictions.bin"
|
||||||
|
predictions_per_class_file = f"{output_path}/predictions_class.bin"
|
||||||
|
|
||||||
# retrieve labels and un-batch them
|
# retrieve labels and un-batch them
|
||||||
files = glob.glob(f"{output_path}/*ssd_labels*")
|
files = glob.glob(f"{output_path}/*ssd_labels*")
|
||||||
@ -205,9 +209,15 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
del _predictions
|
del _predictions
|
||||||
|
|
||||||
# prepare predictions for further use
|
# prepare predictions for further use
|
||||||
|
with open(predictions_file, "wb") as file:
|
||||||
|
pickle.dump(predictions, file)
|
||||||
|
|
||||||
predictions_per_class = evaluate.prepare_predictions(predictions, ssd.N_CLASSES)
|
predictions_per_class = evaluate.prepare_predictions(predictions, ssd.N_CLASSES)
|
||||||
del predictions
|
del predictions
|
||||||
|
|
||||||
|
with open(predictions_per_class_file, "wb") as file:
|
||||||
|
pickle.dump(predictions_per_class, file)
|
||||||
|
|
||||||
# compute matches between predictions and ground truth
|
# compute matches between predictions and ground truth
|
||||||
true_positives, false_positives, \
|
true_positives, false_positives, \
|
||||||
cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class,
|
cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class,
|
||||||
@ -221,7 +231,19 @@ def _ssd_test(args: argparse.Namespace) -> None:
|
|||||||
average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
|
average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
|
||||||
mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
|
mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
|
||||||
|
|
||||||
# TODO store result of evaluation
|
results = {
|
||||||
|
"true_positives": true_positives,
|
||||||
|
"false_positives": false_positives,
|
||||||
|
"cumulative_true_positives": cum_true_positives,
|
||||||
|
"cumulative_false_positives": cum_false_positives,
|
||||||
|
"cumulative_precisions": cum_precisions,
|
||||||
|
"cumulative_recalls": cum_recalls,
|
||||||
|
"average precisions": average_precisions,
|
||||||
|
"mean_average_precision": mean_average_precision
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(result_file, "wb") as file:
|
||||||
|
pickle.dump(results, file)
|
||||||
|
|
||||||
|
|
||||||
def _val(args: argparse.Namespace) -> None:
|
def _val(args: argparse.Namespace) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user