Skip to content

Commit 4150edb

Browse files
authored
[WB-3854]: Out of range error when class names not included as argument (wandb#1479)
* fix conf mat code for no class names, add test * Stacey's neat union cardinality fix * handle edge case where inds are weird
1 parent b1efe21 commit 4150edb

2 files changed

Lines changed: 16 additions & 3 deletions

File tree

standalone_tests/tweets.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,11 @@
5353
class_ind_map[class_name] = i
5454
y_pred_inds = [class_ind_map[class_name] for class_name in y_pred]
5555
y_true_inds = [class_ind_map[class_name] for class_name in y_test]
56-
56+
# test workflow with classes
5757
wandb.log({'conf_mat': wandb.plot.confusion_matrix(y_pred_inds, y_true_inds, nb.classes_)})
58+
# test workflow without classes
59+
wandb.log({'conf_mat_noclass': wandb.plot.confusion_matrix(y_pred_inds, y_true_inds)})
60+
# test workflow with multiples of inds
61+
y_pred_mult = [y_pred_ind*5 for y_pred_ind in y_pred_inds]
62+
y_true_mult = [y_true_ind*5 for y_true_ind in y_true_inds]
63+
wandb.log({'conf_mat_noclass_mult': wandb.plot.confusion_matrix(y_pred_mult, y_true_mult)})

wandb/plot/confusion_matrix.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,26 @@ def confusion_matrix(preds=None, y_true=None, class_names=None):
3030
), "Number of predictions and label indices must match"
3131
if class_names is not None:
3232
n_classes = len(class_names)
33+
class_inds = set(preds).union(set(y_true))
3334
assert max(preds) <= len(
3435
class_names
3536
), "Higher predicted index than number of classes"
3637
assert max(y_true) <= len(
3738
class_names
3839
), "Higher label class index than number of classes"
3940
else:
40-
n_classes = max(max(preds), max(y_true))
41+
class_inds = set(preds).union(set(y_true))
42+
n_classes = len(class_inds)
4143
class_names = ["Class_{}".format(i) for i in range(1, n_classes + 1)]
4244

45+
# get mapping of inds to class index in case user has weird prediction indices
46+
class_mapping = {}
47+
for i, val in enumerate(sorted(list(class_inds))):
48+
class_mapping[val] = i
49+
4350
counts = np.zeros((n_classes, n_classes))
4451
for i in range(len(preds)):
45-
counts[y_true[i], preds[i]] += 1
52+
counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
4653

4754
data = []
4855
for i in range(n_classes):

0 commit comments

Comments
 (0)