Moved code into separate functions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -22,10 +22,12 @@ Functions:
|
|||||||
save_ssd_train_images(...):
|
save_ssd_train_images(...):
|
||||||
saves the first batch of SSD train images with overlaid ground truth bounding boxes
|
saves the first batch of SSD train images with overlaid ground truth bounding boxes
|
||||||
"""
|
"""
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Tuple
|
||||||
from typing import Union, Sequence
|
from typing import Union, Sequence
|
||||||
|
|
||||||
import cv2
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib import pyplot
|
from matplotlib import pyplot
|
||||||
@ -68,13 +70,58 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
|
|||||||
image = Image.fromarray(train_image)
|
image = Image.fromarray(train_image)
|
||||||
image.save(f"{output_path}/"
|
image.save(f"{output_path}/"
|
||||||
f"{custom_string}train_image{str(i).zfill(nr_digits)}.png")
|
f"{custom_string}train_image{str(i).zfill(nr_digits)}.png")
|
||||||
|
figure_filename = f"{output_path}/{custom_string}bboxes{str(i).zfill(nr_digits)}.png"
|
||||||
|
|
||||||
|
_draw_bbox_image(image=image,
|
||||||
|
filename=figure_filename,
|
||||||
|
draw_func=functools.partial(_draw_bboxes,
|
||||||
|
image_size=image_size,
|
||||||
|
classes_to_names=classes_to_names),
|
||||||
|
drawables=[
|
||||||
|
(colors, instances)
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_bbox_image(image: Image,
|
||||||
|
filename: str,
|
||||||
|
draw_func: callable,
|
||||||
|
drawables: Sequence[Tuple[Sequence, Sequence[np.ndarray]]]):
|
||||||
figure = pyplot.figure(figsize=(6.4, 4.8))
|
figure = pyplot.figure(figsize=(6.4, 4.8))
|
||||||
pyplot.imshow(image)
|
pyplot.imshow(image)
|
||||||
|
|
||||||
current_axis = pyplot.gca()
|
current_axis = pyplot.gca()
|
||||||
|
for colors, instances in drawables:
|
||||||
|
draw_func(instances=instances,
|
||||||
|
axis=current_axis,
|
||||||
|
colors=colors)
|
||||||
|
|
||||||
|
pyplot.savefig(filename)
|
||||||
|
pyplot.close(figure)
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_bboxes(instances: Sequence[np.ndarray], axis: pyplot.Axes,
|
||||||
|
image_size: int,
|
||||||
|
colors: Sequence,
|
||||||
|
classes_to_names: Dict[int, str]) -> None:
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
|
if not len(instance):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
class_id, xmin, ymin, xmax, ymax = _get_bbox_info(instance, image_size)
|
||||||
|
|
||||||
|
if class_id == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
color = colors[class_id]
|
||||||
|
label = f"{classes_to_names[class_id]}"
|
||||||
|
axis.add_patch(
|
||||||
|
pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False,
|
||||||
|
linewidth=2))
|
||||||
|
axis.text(xmin, ymin, label, size='x-large', color='white',
|
||||||
|
bbox={'facecolor': color, 'alpha': 1.0})
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bbox_info(instance: np.ndarray, image_size: int) -> Tuple[int, float, float, float, float]:
|
||||||
if len(instance) == 5: # ground truth
|
if len(instance) == 5: # ground truth
|
||||||
class_id = int(instance[0])
|
class_id = int(instance[0])
|
||||||
xmin = instance[1]
|
xmin = instance[1]
|
||||||
@ -93,8 +140,6 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
|
|||||||
ymin = instance[3]
|
ymin = instance[3]
|
||||||
xmax = instance[4]
|
xmax = instance[4]
|
||||||
ymax = instance[5]
|
ymax = instance[5]
|
||||||
elif not len(instance):
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
instance = np.copy(instance)
|
instance = np.copy(instance)
|
||||||
class_id = np.argmax(instance[:-12], axis=0)
|
class_id = np.argmax(instance[:-12], axis=0)
|
||||||
@ -109,15 +154,4 @@ def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.n
|
|||||||
xmax = instance[-10]
|
xmax = instance[-10]
|
||||||
ymax = instance[-9]
|
ymax = instance[-9]
|
||||||
|
|
||||||
if class_id == 0:
|
return class_id, xmin, ymin, xmax, ymax
|
||||||
continue
|
|
||||||
|
|
||||||
color = colors[class_id]
|
|
||||||
label = f"{classes_to_names[class_id]}"
|
|
||||||
current_axis.add_patch(
|
|
||||||
pyplot.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color=color, fill=False,
|
|
||||||
linewidth=2))
|
|
||||||
current_axis.text(xmin, ymin, label, size='x-large', color='white',
|
|
||||||
bbox={'facecolor': color, 'alpha': 1.0})
|
|
||||||
pyplot.savefig(f"{output_path}/{custom_string}bboxes{str(i).zfill(nr_digits)}.png")
|
|
||||||
pyplot.close(figure)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user