diff --git a/src/twomartens/masterthesis/cli.py b/src/twomartens/masterthesis/cli.py index 59ce0f1..223e4f6 100644 --- a/src/twomartens/masterthesis/cli.py +++ b/src/twomartens/masterthesis/cli.py @@ -85,13 +85,55 @@ def _ssd_train(args: argparse.Namespace) -> None: del file_names_train, instances_train, file_names_val, instances_val if args.debug: + from matplotlib import pyplot + import numpy as np + from PIL import Image + + from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils + train_data = next(train_generator) train_length -= batch_size - from PIL import Image train_images = train_data[0] + train_labels = train_data[1] + + annotation_file_train = f"{args.coco_path}/annotations/instances_train2014.json" + _, _, _, classes_to_names = coco_utils.get_coco_category_maps(annotation_file_train) + colors = pyplot.cm.hsv(np.linspace(0, 1, 81)).tolist() + + nr_images = len(train_images) + nr_digits = math.ceil(math.log10(nr_images)) + for i, train_image in enumerate(train_images): + instances = train_labels[i] image = Image.fromarray(train_image) - image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/train_image{i}.png") + image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/" + f"train_image{str(i).zfill(nr_digits)}.png") + + if not instances: + continue + + figure = pyplot.figure(figsize=(20, 12)) + pyplot.imshow(image) + + current_axis = pyplot.gca() + + for instance in instances: + xmin = instance[-12] * image_size + ymin = instance[-11] * image_size + xmax = instance[-10] * image_size + ymax = instance[-9] * image_size + class_id = np.argmax(instance[:-12]) + print(class_id) + color = colors[class_id] + label = f"{classes_to_names[class_id]}" + current_axis.add_patch( + pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False, + linewidth=2)) + current_axis.text(xmin, ymin, label, size='x-large', color='white', + bbox={'facecolor': color, 'alpha': 1.0}) + pyplot.savefig(f"{args.summary_path}/train/{args.network}/{args.iteration}/bboxes{str(i).zfill(nr_digits)}.png") + pyplot.close(figure) + nr_batches_train = int(math.floor(train_length / batch_size)) nr_batches_val = int(math.floor(val_length / batch_size))