Fixed debug saving of images to work with predictions

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-11 17:12:54 +02:00
parent 8ef685dae9
commit 6f6b4b7e34

View File

@ -70,12 +70,18 @@ def save_ssd_train_images(images: np.ndarray, labels: np.ndarray,
current_axis = pyplot.gca() current_axis = pyplot.gca()
for instance in instances: for instance in instances:
if len(instance) == 5: if len(instance) == 5: # ground truth
class_id = int(instance[0]) class_id = int(instance[0])
xmin = instance[1] xmin = instance[1]
ymin = instance[2] ymin = instance[2]
xmax = instance[3] xmax = instance[3]
ymax = instance[4] ymax = instance[4]
elif len(instance) == 7: # predictions
class_id = int(instance[0])
xmin = instance[3]
ymin = instance[4]
xmax = instance[5]
ymax = instance[6]
else: else:
instance = np.asarray(instance) instance = np.asarray(instance)
class_id = np.argmax(instance[:-12], axis=0) class_id = np.argmax(instance[:-12], axis=0)