Changed save_ssd_train_images func to work with filenames as well

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-15 13:25:47 +02:00
parent 5aac93f444
commit 2f03be42dd

View File

@ -23,6 +23,7 @@ Functions:
saves the first batch of SSD train images with overlaid ground truth bounding boxes saves the first batch of SSD train images with overlaid ground truth bounding boxes
""" """
import os import os
from typing import Union, Sequence
import math import math
import numpy as np import numpy as np
@ -30,7 +31,7 @@ from matplotlib import pyplot
from PIL import Image from PIL import Image
def save_ssd_train_images(images: np.ndarray, labels: np.ndarray, def save_ssd_train_images(images: Union[np.ndarray, Sequence[str]], labels: np.ndarray,
output_path: str, coco_path: str, output_path: str, coco_path: str,
image_size: int, get_coco_cat_maps_func: callable, image_size: int, get_coco_cat_maps_func: callable,
custom_string: str = None) -> None: custom_string: str = None) -> None:
@ -40,7 +41,7 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray,
The images are saved both in a raw version and with bounding boxes printed on them. The images are saved both in a raw version and with bounding boxes printed on them.
Args: Args:
images: a NumPy array of images images: a NumPy array of images or a list of filenames
labels: a NumPy array of labels labels: a NumPy array of labels
output_path: path to save the images in output_path: path to save the images in
coco_path: path to the COCO data set coco_path: path to the COCO data set
@ -60,6 +61,10 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray,
for i, train_image in enumerate(images): for i, train_image in enumerate(images):
instances = labels[i] instances = labels[i]
if type(train_image) is str:
with Image.open(train_image) as _image:
image = _image
else:
image = Image.fromarray(train_image) image = Image.fromarray(train_image)
image.save(f"{output_path}/" image.save(f"{output_path}/"
f"{custom_string}train_image{str(i).zfill(nr_digits)}.png") f"{custom_string}train_image{str(i).zfill(nr_digits)}.png")