Fixed encoding of labels

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-06-05 12:35:08 +02:00
parent eadd4a9273
commit efda35adb7

View File

@ -456,7 +456,12 @@ def _train_one_epoch(epoch: int,
# go through data set
for x, y in dataset:
encoded_ground_truth = input_encoder(y)
labels = []
for i in range(y.shape[0]):
image_labels = np.asarray(y[i])
image_labels = image_labels[image_labels[:, 0] != -1]
labels.append(image_labels)
encoded_ground_truth = input_encoder(labels)
ssd_train_loss = _train_ssd_step(ssd=ssd,
optimizer=ssd_optimizer,
inputs=x,