Implemented evaluation for vanilla SSD

Uses the evaluate functions. They follow ssd_keras very closely but
were necessary because the TensorFlow data pipeline is used instead of
the DataGenerator provided by ssd_keras.

Signed-off-by: Jim Martens <github@2martens.de>
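ssd_keras' Evaluator is built around its DataGenerator, so ground truth has to be collected from the tf.data pipeline by hand before the evaluate functions can run. A minimal, hypothetical sketch of that collection step in eager mode (TF 1.x); collect_labels and the dataset layout are illustrative assumptions, not code from this commit:

import tensorflow as tf

tf.enable_eager_execution()

def collect_labels(dataset: tf.data.Dataset) -> list:
    # hypothetical: un-batch a dataset of (image, labels) pairs so the
    # result holds one ground-truth array per image
    labels = []
    for _, label_batch in dataset:  # direct iteration works eagerly
        labels.extend(label_batch.numpy())
    return labels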
@@ -22,6 +22,7 @@ Functions:
 """
 import argparse
 
+
 def main() -> None:
     """
     Provides command line interface.
@@ -163,9 +164,11 @@ def _ssd_test(args: argparse.Namespace) -> None:
     import glob
     import pickle
 
     import numpy as np
     import tensorflow as tf
 
+    from twomartens.masterthesis import evaluate
+    from twomartens.masterthesis import ssd
 
     tf.enable_eager_execution()
 
     batch_size = 16
@@ -187,7 +190,38 @@ def _ssd_test(args: argparse.Namespace) -> None:
     with open(label_file, "wb") as file:
         pickle.dump(labels, file)
 
-    # TODO implement evaluate.py analogous to average_precision_evaluator
+    number_gt_per_class = evaluate.get_number_gt_per_class(labels, ssd.N_CLASSES)
+
+    # retrieve predictions and un-batch them
+    files = glob.glob(f"{output_path}/*ssd_predictions*")
+    predictions = []
+    for filename in files:
+        with open(filename, "rb") as file:
+            # get predictions per batch
+            _predictions = pickle.load(file)
+            # select only forward pass
+            _predictions = _predictions[0]
+        predictions.extend(_predictions)
+        del _predictions
+
+    # prepare predictions for further use
+    predictions_per_class = evaluate.prepare_predictions(predictions, ssd.N_CLASSES)
+    del predictions
+
+    # compute matches between predictions and ground truth
+    true_positives, false_positives, \
+        cum_true_positives, cum_false_positives = evaluate.match_predictions(predictions_per_class,
+                                                                             labels,
+                                                                             ssd.N_CLASSES)
+    del labels
+    cum_precisions, cum_recalls = evaluate.get_precision_recall(number_gt_per_class,
+                                                                cum_true_positives,
+                                                                cum_false_positives,
+                                                                ssd.N_CLASSES)
+    average_precisions = evaluate.get_mean_average_precisions(cum_precisions, cum_recalls, ssd.N_CLASSES)
+    mean_average_precision = evaluate.get_mean_average_precision(average_precisions)
+
+    # TODO store result of evaluation
 
 
 def _val(args: argparse.Namespace) -> None:
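The evaluate helpers themselves are not part of this diff. Judging from the call sites above and the ssd_keras Evaluator they follow, get_precision_recall presumably turns per-class cumulative true/false positive counts (sorted by descending confidence) into cumulative precision and recall curves. A sketch under that assumption, not the actual implementation:

import numpy as np

def get_precision_recall(number_gt_per_class, cum_true_positives,
                         cum_false_positives, n_classes):
    # index 0 is reserved for the background class
    cum_precisions, cum_recalls = [[]], [[]]
    for class_id in range(1, n_classes + 1):
        tp = cum_true_positives[class_id]
        fp = cum_false_positives[class_id]
        # precision: share of detections up to this rank that are correct
        denominator = tp + fp
        precision = np.divide(tp, denominator,
                              out=np.zeros(tp.shape, dtype=np.float64),
                              where=denominator > 0)
        # recall: share of ground-truth boxes found up to this rank
        recall = tp / number_gt_per_class[class_id]
        cum_precisions.append(precision)
        cum_recalls.append(recall)
    return cum_precisions, cum_recalls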
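Similarly, get_mean_average_precisions and get_mean_average_precision plausibly follow the Pascal VOC scheme used by ssd_keras; below is a sketch of the 11-point variant (ssd_keras also offers an integration mode, and which one this commit actually uses is not visible here):

import numpy as np

def get_mean_average_precisions(cum_precisions, cum_recalls, n_classes):
    average_precisions = [0.0]  # index 0: background
    for class_id in range(1, n_classes + 1):
        precisions = np.asarray(cum_precisions[class_id])
        recalls = np.asarray(cum_recalls[class_id])
        ap = 0.0
        for threshold in np.linspace(0.0, 1.0, 11):
            mask = recalls >= threshold
            # best precision achieved at recall >= threshold, 0 if unreached
            ap += precisions[mask].max() if mask.any() else 0.0
        average_precisions.append(ap / 11)
    return average_precisions

def get_mean_average_precision(average_precisions):
    # mAP: mean over the per-class APs, skipping the background slot
    return np.mean(average_precisions[1:])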