diff --git a/src/twomartens/masterthesis/debug.py b/src/twomartens/masterthesis/debug.py index acc1687..c7ca527 100644 --- a/src/twomartens/masterthesis/debug.py +++ b/src/twomartens/masterthesis/debug.py @@ -70,12 +70,18 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray, current_axis = pyplot.gca() for instance in instances: - if len(instance) == 5: + if len(instance) == 5: # ground truth class_id = int(instance[0]) xmin = instance[1] ymin = instance[2] xmax = instance[3] ymax = instance[4] + elif len(instance) == 7: # predictions + class_id = int(instance[0]) + xmin = instance[3] + ymin = instance[4] + xmax = instance[5] + ymax = instance[6] else: instance = np.asarray(instance) class_id = np.argmax(instance[:-12], axis=0)