From 7969716e6df521727bf3ed4387e3b795ef962162 Mon Sep 17 00:00:00 2001 From: Jim Martens Date: Wed, 10 Jul 2019 15:51:20 +0200 Subject: [PATCH] Made function compatible with encoded and decoded labels Signed-off-by: Jim Martens --- src/twomartens/masterthesis/debug.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/twomartens/masterthesis/debug.py b/src/twomartens/masterthesis/debug.py index 68a6063..ae8763f 100644 --- a/src/twomartens/masterthesis/debug.py +++ b/src/twomartens/masterthesis/debug.py @@ -70,13 +70,27 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray, current_axis = pyplot.gca() for instance in instances: - xmin = (instance[-12]) * image_size - ymin = (instance[-11]) * image_size - xmax = (instance[-10]) * image_size - ymax = (instance[-9]) * image_size - class_id = np.argmax(instance[:-12], axis=0) + if len(instance) == 5: + class_id = instance[0] + xmin = instance[1] + ymin = instance[2] + xmax = instance[3] + ymax = instance[4] + else: + class_id = np.argmax(instance[:-12], axis=0) + xmin = instance[-12] + instance[-8] + ymin = instance[-11] + instance[-7] + xmax = instance[-10] + instance[-6] + ymax = instance[-9] + instance[-5] + if class_id == 0: continue + + xmin *= image_size + ymin *= image_size + xmax *= image_size + ymax *= image_size + color = colors[class_id] label = f"{classes_to_names[class_id]}" current_axis.add_patch(