Moved code into separate functions

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-09-16 11:02:08 +02:00
parent 8230488031
commit 554ccc84f9

View File

@ -22,10 +22,12 @@ Functions:
save_ssd_train_images(...):
saves the first batch of SSD train images with overlaid ground truth bounding boxes
"""
import functools
import os
from typing import Dict
from typing import Tuple
from typing import Union, Sequence
import cv2
import math
import numpy as np
from matplotlib import pyplot
@ -68,13 +70,58 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
image = Image.fromarray(train_image)
image.save(f"{output_path}/"
f"{custom_string}train_image{str(i).zfill(nr_digits)}.png")
figure_filename = f"{output_path}/{custom_string}bboxes{str(i).zfill(nr_digits)}.png"
_draw_bbox_image(image=image,
filename=figure_filename,
draw_func=functools.partial(_draw_bboxes,
image_size=image_size,
classes_to_names=classes_to_names),
drawables=[
(colors, instances)
])
def _draw_bbox_image(image: Image,
filename: str,
draw_func: callable,
drawables: Sequence[Tuple[Sequence, Sequence[np.ndarray]]]):
figure = pyplot.figure(figsize=(6.4, 4.8))
pyplot.imshow(image)
current_axis = pyplot.gca()
for colors, instances in drawables:
draw_func(instances=instances,
axis=current_axis,
colors=colors)
pyplot.savefig(filename)
pyplot.close(figure)
def _draw_bboxes(instances: Sequence[np.ndarray], axis: pyplot.Axes,
image_size: int,
colors: Sequence,
classes_to_names: Dict[int, str]) -> None:
for instance in instances:
if not len(instance):
continue
else:
class_id, xmin, ymin, xmax, ymax = _get_bbox_info(instance, image_size)
if class_id == 0:
continue
color = colors[class_id]
label = f"{classes_to_names[class_id]}"
axis.add_patch(
pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False,
linewidth=2))
axis.text(xmin, ymin, label, size='x-large', color='white',
bbox={'facecolor': color, 'alpha': 1.0})
def _get_bbox_info(instance: np.ndarray, image_size: int) -> Tuple[int, float, float, float, float]:
if len(instance) == 5: # ground truth
class_id = int(instance[0])
xmin = instance[1]
@ -93,8 +140,6 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
ymin = instance[3]
xmax = instance[4]
ymax = instance[5]
elif not len(instance):
continue
else:
instance = np.copy(instance)
class_id = np.argmax(instance[:-12], axis=0)
@ -109,15 +154,4 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
xmax = instance[-10]
ymax = instance[-9]
if class_id == 0:
continue
color = colors[class_id]
label = f"{classes_to_names[class_id]}"
current_axis.add_patch(
pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False,
linewidth=2))
current_axis.text(xmin, ymin, label, size='x-large', color='white',
bbox={'facecolor': color, 'alpha': 1.0})
pyplot.savefig(f"{output_path}/{custom_string}bboxes{str(i).zfill(nr_digits)}.png")
pyplot.close(figure)
return class_id, xmin, ymin, xmax, ymax