Made function compatible with encoded and decoded labels
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -70,13 +70,27 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray,
|
||||
current_axis = pyplot.gca()
|
||||
|
||||
for instance in instances:
|
||||
xmin = (instance[-12]) * image_size
|
||||
ymin = (instance[-11]) * image_size
|
||||
xmax = (instance[-10]) * image_size
|
||||
ymax = (instance[-9]) * image_size
|
||||
class_id = np.argmax(instance[:-12], axis=0)
|
||||
if len(instance) == 5:
|
||||
class_id = instance[0]
|
||||
xmin = instance[1]
|
||||
ymin = instance[2]
|
||||
xmax = instance[3]
|
||||
ymax = instance[4]
|
||||
else:
|
||||
class_id = np.argmax(instance[:-12], axis=0)
|
||||
xmin = instance[-12] + instance[-8]
|
||||
ymin = instance[-11] + instance[-7]
|
||||
xmax = instance[-10] + instance[-6]
|
||||
ymax = instance[-9] + instance[-5]
|
||||
|
||||
if class_id == 0:
|
||||
continue
|
||||
|
||||
xmin *= image_size
|
||||
ymin *= image_size
|
||||
xmax *= image_size
|
||||
ymax *= image_size
|
||||
|
||||
color = colors[class_id]
|
||||
label = f"{classes_to_names[class_id]}"
|
||||
current_axis.add_patch(
|
||||
|
||||
Reference in New Issue
Block a user