diff --git a/src/twomartens/masterthesis/main.py b/src/twomartens/masterthesis/main.py index 9c8e13f..0f93e94 100644 --- a/src/twomartens/masterthesis/main.py +++ b/src/twomartens/masterthesis/main.py @@ -182,7 +182,8 @@ def _test(args: argparse.Namespace) -> None: def _ssd_test(args: argparse.Namespace) -> None: import glob import pickle - + + import numpy as np import tensorflow as tf from twomartens.masterthesis import evaluate @@ -207,7 +208,7 @@ def _ssd_test(args: argparse.Namespace) -> None: # get labels per batch _labels = pickle.load(file) # exclude padded label entries - real_labels = _labels[:, :, 0] != -1 + real_labels = np.nonzero(_labels[:, :, 0] != -1) labels.extend(_labels[real_labels]) # store labels for later use with open(label_file, "wb") as file: