Removed guard clause and use 0 as precision/recall if no class predictions
Signed-off-by: Jim Martens <github@2martens.de>
This commit is contained in:
@ -305,7 +305,7 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
|
||||
# Iterate over all classes.
|
||||
for class_id in range(1, nr_classes + 1):
|
||||
|
||||
if number_gt_per_class[class_id] == 0 or cumulative_true_positives[class_id].shape[0] == 0:
|
||||
if number_gt_per_class[class_id] == 0:
|
||||
cumulative_precisions.append([])
|
||||
cumulative_recalls.append([])
|
||||
continue
|
||||
@ -323,8 +323,10 @@ def get_precision_recall(number_gt_per_class: np.ndarray,
|
||||
|
||||
diff_to_largest_class = cumulative_precision_micro.shape[0] - cumulative_precision.shape[0]
|
||||
if diff_to_largest_class:
|
||||
repeated_last_precision = np.tile(cumulative_precision[-1], diff_to_largest_class)
|
||||
repeated_last_recall = np.tile(cumulative_recall[-1], diff_to_largest_class)
|
||||
highest_precision = cumulative_precision[-1] if cumulative_precision.shape[0] else 0
|
||||
highest_recall = cumulative_recall[-1] if cumulative_recall.shape[0] else 0
|
||||
repeated_last_precision = np.tile(highest_precision, diff_to_largest_class)
|
||||
repeated_last_recall = np.tile(highest_recall, diff_to_largest_class)
|
||||
extended_precision = np.concatenate((cumulative_precision, repeated_last_precision))
|
||||
extended_recall = np.concatenate((cumulative_recall, repeated_last_recall))
|
||||
cumulative_precision_macro += extended_precision
|
||||
|
||||
Reference in New Issue
Block a user