diff --git a/src/twomartens/masterthesis/debug.py b/src/twomartens/masterthesis/debug.py index 68a6063..ae8763f 100644 --- a/src/twomartens/masterthesis/debug.py +++ b/src/twomartens/masterthesis/debug.py @@ -70,13 +70,27 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray, current_axis = pyplot.gca() for instance in instances: - xmin = (instance[-12]) * image_size - ymin = (instance[-11]) * image_size - xmax = (instance[-10]) * image_size - ymax = (instance[-9]) * image_size - class_id = np.argmax(instance[:-12], axis=0) + if len(instance) == 5: + class_id = instance[0] + xmin = instance[1] + ymin = instance[2] + xmax = instance[3] + ymax = instance[4] + else: + class_id = np.argmax(instance[:-12], axis=0) + xmin = instance[-12] + instance[-8] + ymin = instance[-11] + instance[-7] + xmax = instance[-10] + instance[-6] + ymax = instance[-9] + instance[-5] + if class_id == 0: continue + + xmin *= image_size + ymin *= image_size + xmax *= image_size + ymax *= image_size + color = colors[class_id] label = f"{classes_to_names[class_id]}" current_axis.add_patch(