Added code to save training images with bounding boxes

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-02 15:45:04 +02:00
parent 112dc48f36
commit 22aa463152

View File

@ -85,13 +85,55 @@ def _ssd_train(args: argparse.Namespace) -> None:
del file_names_train, instances_train, file_names_val, instances_val del file_names_train, instances_train, file_names_val, instances_val
if args.debug: if args.debug:
from matplotlib import pyplot
import numpy as np
from PIL import Image
from twomartens.masterthesis.ssd_keras.eval_utils import coco_utils
train_data = next(train_generator) train_data = next(train_generator)
train_length -= batch_size train_length -= batch_size
from PIL import Image
train_images = train_data[0] train_images = train_data[0]
train_labels = train_data[1]
annotation_file_train = f"{args.coco_path}/annotations/instances_train2014.json"
_, _, _, classes_to_names = coco_utils.get_coco_category_maps(annotation_file_train)
colors = pyplot.cm.hsv(np.linspace(0, 1, 81)).tolist()
nr_images = len(train_images)
nr_digits = math.ceil(math.log10(nr_images))
for i, train_image in enumerate(train_images): for i, train_image in enumerate(train_images):
instances = train_labels[i]
image = Image.fromarray(train_image) image = Image.fromarray(train_image)
image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/train_image{i}.png") image.save(f"{args.summary_path}/train/{args.network}/{args.iteration}/"
f"train_image{str(i).zfill(nr_digits)}.png")
if not instances:
continue
figure = pyplot.figure(figsize=(20, 12))
pyplot.imshow(image)
current_axis = pyplot.gca()
for instance in instances:
xmin = instance[-12] * image_size
ymin = instance[-11] * image_size
xmax = instance[-10] * image_size
ymax = instance[-9] * image_size
class_id = np.argmax(instance[:-12])
print(class_id)
color = colors[class_id]
label = f"{classes_to_names[class_id]}"
current_axis.add_patch(
pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False,
linewidth=2))
current_axis.text(xmin, ymin, label, size='x-large', color='white',
bbox={'facecolor': color, 'alpha': 1.0})
pyplot.savefig(f"{args.summary_path}/train/{args.network}/{args.iteration}/bboxes{str(i).zfill(nr_digits)}.png")
pyplot.close(figure)
nr_batches_train = int(math.floor(train_length / batch_size)) nr_batches_train = int(math.floor(train_length / batch_size))
nr_batches_val = int(math.floor(val_length / batch_size)) nr_batches_val = int(math.floor(val_length / batch_size))