Added application of inverse transforms and removal of dummy predictions

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-05-15 14:30:53 +02:00
parent dcee6643fe
commit 9f5bea9c88

View File

@ -35,6 +35,7 @@ Functions:
train(...): trains the SSD/DropoutSSD on a given data set
"""
import os
import pickle
import time
from typing import Dict
from typing import Optional
@ -191,28 +192,53 @@ def _predict_one_epoch(dataset: tf.data.Dataset,
if use_dropout:
for _ in range(forward_passes_per_image):
result = np.array(ssd(inputs))
result_filtered = []
# iterate over result of images
for i in range(result.shape[0]):
# apply inverse transformations to predicted bounding box coordinates
# filter out dummy all-zero results
x_reverse = labels[i, 0, 5]
y_reverse = labels[i, 0, 6]
filtered = result[i][result[i, :, 0] != 0]
filtered[:, 2] *= x_reverse
filtered[:, 4] *= x_reverse
filtered[:, 3] *= y_reverse
filtered[:, 5] *= y_reverse
result_filtered.append(filtered)
result = result_filtered
decoded_predictions_batch.append(result)
del result
else:
result = np.array(ssd(inputs))
result_filtered = []
# iterate over result of images
for i in range(result.shape[0]):
# apply inverse transformations to predicted bounding box coordinates
# filter out dummy all-zero results
x_reverse = labels[i, 0, 5]
y_reverse = labels[i, 0, 6]
filtered = result[i][result[i, :, 0] != 0]
filtered[:, 2] *= x_reverse
filtered[:, 4] *= x_reverse
filtered[:, 3] *= y_reverse
filtered[:, 5] *= y_reverse
result_filtered.append(filtered)
result = result_filtered
decoded_predictions_batch.append(result)
del result
# save predictions batch-wise to prevent memory problems
if nr_digits is not None:
counter_str = str(counter).zfill(nr_digits)
filename = f"{output_file}-{counter_str}.npy"
label_filename = f"{label_output_file}-{counter_str}.npy"
filename = f"{output_file}-{counter_str}.bin"
label_filename = f"{label_output_file}-{counter_str}.bin"
else:
filename = f"{output_file}-{counter:d}.npy"
label_filename = f"{label_output_file}-{counter:d}.npy"
filename = f"{output_file}-{counter:d}.bin"
label_filename = f"{label_output_file}-{counter:d}.bin"
with open(filename, 'wb') as file, open(label_filename, 'wb') as label_file:
decoded_predictions_batch_np = np.array(decoded_predictions_batch)
del decoded_predictions_batch
np.save(file, decoded_predictions_batch_np, allow_pickle=False, fix_imports=False)
del decoded_predictions_batch_np
np.save(label_file, labels, allow_pickle=False, fix_imports=False)
pickle.dump(decoded_predictions_batch, file)
pickle.dump(labels, label_file)
counter += 1