Added additional guard clause for cases with no true/false positive entries

Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
2019-07-18 14:58:06 +02:00
parent 60fb13cddf
commit 8897a65b3d

View File

@ -305,7 +305,7 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
# Iterate over all classes.
for class_id in range(1, nr_classes + 1):
if number_gt_per_class[class_id] == 0:
if number_gt_per_class[class_id] == 0 or cumulative_true_positives[class_id].shape[0] == 0:
cumulative_precisions.append([])
cumulative_recalls.append([])
continue
@ -323,8 +323,6 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
diff_to_largest_class = cumulative_precision_micro.shape[0] - cumulative_precision.shape[0]
if diff_to_largest_class:
print(f"Diff to largest: {diff_to_largest_class}")
print(f"Shape: {tp.shape}")
repeated_last_precision = np.tile(cumulative_precision[-1], diff_to_largest_class)
repeated_last_recall = np.tile(cumulative_recall[-1], diff_to_largest_class)
extended_precision = np.concatenate((cumulative_precision, repeated_last_precision))