diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index ca1bea9..dda77da 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -121,7 +121,7 @@ def _ssd_train(args: argparse.Namespace) -> None: ymin = instance[-11] * image_size xmax = instance[-10] * image_size ymax = instance[-9] * image_size - class_id = np.argmax(instance[:-12]) + class_id = np.argmax(instance[:-12], axis=0) color = colors[class_id] label = f"{classes_to_names[class_id]}" current_axis.add_patch(