feat: Implemented vectorized update of confusion matrices
FabricioArendTorres committed Jul 19, 2024
1 parent 38494e9 commit 010fc72
Showing 2 changed files with 144 additions and 26 deletions.
51 changes: 28 additions & 23 deletions streamauc/streaming_metrics.py
@@ -47,19 +47,19 @@ def _validate_thresholds(

class StreamingMetrics:
"""
Class for computing metrics in a minibatch-wise, iterative, fashion.
Class for keeping track of metrics for many thresholds in a
minibatch-wise, iterative fashion.
Parameters
----------
num_thresholds : int, optional
Number of thresholds to evaluate the curve. Default is 200.
curve_type : str, optional
Type of curve to compute, either "ROC" or "PR". Default is "PR".
num_classes : int
Number of classes in the multiclass setting. Must be >= 2.
thresholds : list of float, optional
List of specific thresholds to evaluate the curve.
List of specific thresholds to evaluate the metrics at.
A probability >= threshold is defined as a positive prediction for
the respective class.
"""

def __init__(
@@ -140,7 +140,7 @@ def update(
If the shapes of `y_true` and `y_pred` do not match.
"""

y_true = np.squeeze(y_true)
y_true = np.squeeze(y_true).astype(int)
y_score = np.squeeze(y_score)

if check_inputs:
@@ -149,6 +149,10 @@
f"Unknown shape of y_true: {y_true.shape},"
f"must be squeezable to either [-1, num_classes] or [-1]."
)
if y_true.ndim == 2 and np.any(y_true.sum(-1) != 1):
raise ValueError(
"The provided one-hot encoding is invalid."
)
if y_score.ndim > 2:
raise ValueError(
f"Unknown shape of y_true: {y_true.shape},"
@@ -164,24 +168,25 @@
raise ValueError(f"Invalid shape of y_pred: {y_score.shape}")

if y_true.ndim == 2 and y_true.shape[1] == self.num_classes:
y_true_argmax = np.argmax(y_true, -1)
y_onehot = y_true
else:
y_true_argmax = y_true

for threshold_idx, threshold in enumerate(self.thresholds):
for class_idx in range(self.num_classes):
pred_pos = y_score[:, class_idx] >= threshold
is_pos = y_true_argmax == class_idx

tp = np.sum(pred_pos & is_pos)
fp = np.sum(pred_pos & (~is_pos))
fn = np.sum((~pred_pos) & (is_pos))
tn = np.sum((~pred_pos) & (~is_pos))

self._confusion_matrix[threshold_idx, class_idx, 0, 0] += tp
self._confusion_matrix[threshold_idx, class_idx, 1, 0] += fp
self._confusion_matrix[threshold_idx, class_idx, 1, 1] += tn
self._confusion_matrix[threshold_idx, class_idx, 0, 1] += fn
y_onehot = np.eye(self.num_classes, dtype=int)[y_true]

# use numpy broadcasting to get predictions
pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(-1, 1, 1)
is_pos = y_onehot[np.newaxis, ...]

# sum over the minibatch samples
tp = np.sum(pred_pos & is_pos, 1)
fp = np.sum(pred_pos & (~is_pos), 1)
fn = np.sum((~pred_pos) & (is_pos), 1)
tn = np.sum((~pred_pos) & (~is_pos), 1)

# update the confusion matrix entries
self._confusion_matrix[..., 0, 0] += tp
self._confusion_matrix[..., 1, 0] += fp
self._confusion_matrix[..., 1, 1] += tn
self._confusion_matrix[..., 0, 1] += fn

def _total(self) -> np.ndarray:
"""
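
For reference, the broadcasting idea behind the new update can be reproduced in a few lines. The sketch below uses dummy data and illustrative names (it mirrors the diff above rather than the library's public API) and cross-checks the vectorized counts against the per-threshold, per-class loop that this commit removes:

import numpy as np

rng = np.random.default_rng(0)
n_samples, num_classes = 8, 3
thresholds = np.linspace(0.0, 1.0, 5)                  # shape (T,)

y_true = rng.integers(0, num_classes, size=n_samples)  # integer labels, shape (n,)
y_score = rng.random((n_samples, num_classes))
y_score /= y_score.sum(-1, keepdims=True)              # rows sum to 1

y_onehot = np.eye(num_classes, dtype=bool)[y_true]     # shape (n, C)

# Broadcast scores (1, n, C) against thresholds (T, 1, 1) -> booleans of shape (T, n, C).
pred_pos = y_score[np.newaxis, ...] >= thresholds.reshape(-1, 1, 1)
is_pos = y_onehot[np.newaxis, ...]

# Summing over the sample axis gives per-threshold, per-class counts of shape (T, C).
tp = np.sum(pred_pos & is_pos, 1)
fp = np.sum(pred_pos & ~is_pos, 1)
fn = np.sum(~pred_pos & is_pos, 1)
tn = np.sum(~pred_pos & ~is_pos, 1)

# Cross-check against the per-threshold, per-class loop the commit replaces.
for t_idx, thr in enumerate(thresholds):
    for c in range(num_classes):
        pp = y_score[:, c] >= thr
        ip = y_true == c
        assert tp[t_idx, c] == np.sum(pp & ip)
        assert fp[t_idx, c] == np.sum(pp & ~ip)
        assert fn[t_idx, c] == np.sum(~pp & ip)
        assert tn[t_idx, c] == np.sum(~pp & ~ip)

The key point is that comparing scores of shape (1, n, num_classes) against thresholds reshaped to (T, 1, 1) yields a (T, n, num_classes) boolean array, so a single sum over the sample axis produces all per-threshold, per-class counts at once.
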
119 changes: 116 additions & 3 deletions tests/test_streaming_metrics.py
@@ -4,7 +4,7 @@
from streamauc.metrics import f1_score, tpr, fpr
import numpy as np

from sklearn.datasets import load_iris
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
@@ -83,7 +83,8 @@ def test_reset(self):
curve.confusion_matrix, expected_empty_confm
)

y_true = np.random.randint(0, 2, (10, curve.num_classes))
y_true = np.random.randint(0, 2, (10,))
y_true = np.eye(curve.num_classes)[y_true]
y_pred = np.random.random((10, curve.num_classes))
y_pred = y_pred / y_pred.sum(-1, keepdims=True)

@@ -132,7 +133,8 @@ def test_invalid_input(self):
curve.update(y_true=y_true, y_score=y_pred)

# should not throw any errors
y_true = np.random.randint(0, 2, (10, curve.num_classes, 1, 1, 1, 1))
y_true = np.random.randint(0, 2, (10,))
y_true = np.eye(curve.num_classes)[y_true][..., np.newaxis, np.newaxis]
y_pred = np.random.randint(0, 2, (10, curve.num_classes))
curve.update(y_true=y_true, y_score=y_pred)

@@ -250,6 +252,117 @@ def test_sklearn(self):


class TestStreamingMetrics(unittest.TestCase):
def setUp(self):
cancer_ds = load_breast_cancer()
X, y = cancer_ds.data, cancer_ds.target

random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.concatenate(
[X, random_state.randn(n_samples, 200 * n_features)], axis=1
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.5, stratify=y, random_state=0
)

classifier = LogisticRegression(max_iter=1000)
self.y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

self.y_test = y_test

thresholds = np.unique(self.y_score)
self.dim = 2
self.curve = StreamingMetrics(
thresholds=thresholds,
num_classes=self.dim,
)

# check that multiple updates have the same effect as one big update
half = self.y_test.shape[0] // 2
self.curve.update(self.y_test[:half], self.y_score[:half])
self.curve.update(self.y_test[half:], self.y_score[half:])

def test_total(self):
new_curve = StreamingMetrics(
num_thresholds=100,
num_classes=self.dim,
)

self.assertEqual(new_curve._total().shape, (100, self.dim))
np.testing.assert_allclose(
new_curve._total(), np.zeros_like(new_curve._total())
)

new_curve.update(self.y_test, self.y_score)
new_curve.update(self.y_test, self.y_score)
self.assertEqual(new_curve._total().shape, (100, self.dim))

np.testing.assert_allclose(
new_curve._total(),
2 * self.y_test.shape[0] * np.ones_like(new_curve._total()),
)

def test_confusion_matrix(self):
for class_idx in range(self.dim):
y_true = self.y_test == class_idx

for threshold in self.curve.thresholds:
y_pred = self.y_score[:, class_idx] >= threshold

# sklearn has the confusion matrix flipped
confm_ref = np.flip(confusion_matrix(y_true, y_pred))

computed_confm = self.curve.confusion_matrix[
self.curve.thresholds.tolist().index(threshold), class_idx
]
np.testing.assert_array_equal(computed_confm, confm_ref)
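
An aside to make the flip above concrete (illustration only, not part of the test file): sklearn.metrics.confusion_matrix returns counts as [[tn, fp], [fn, tp]], whereas the layout filled by update in this commit is [[tp, fn], [fp, tn]], so reversing both axes lines the two up.

import numpy as np
from sklearn.metrics import confusion_matrix

labels = np.array([1, 0, 1, 1, 0, 0, 1])
preds = np.array([1, 0, 0, 1, 1, 0, 1])

sk_confm = confusion_matrix(labels, preds)   # sklearn layout: [[tn, fp], [fn, tp]]

tp = np.sum((preds == 1) & (labels == 1))
fp = np.sum((preds == 1) & (labels == 0))
fn = np.sum((preds == 0) & (labels == 1))
tn = np.sum((preds == 0) & (labels == 0))
internal = np.array([[tp, fn], [fp, tn]])    # layout used by StreamingMetrics

# reversing both axes of the sklearn matrix recovers the internal layout
np.testing.assert_array_equal(np.flip(sk_confm), internal)
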

def test_precision_recall_curve(self):
for class_idx in range(self.dim):
precision, recall, thresholds = sk_precision_recall_curve(
self.y_test == class_idx, self.y_score[:, class_idx]
)

new_curve = StreamingMetrics(
thresholds=thresholds,
num_classes=self.dim,
)

# check that multiple updates have the same effect as one big update
half = self.y_test.shape[0] // 2
new_curve.update(self.y_test[:half], self.y_score[:half])
new_curve.update(self.y_test[half:], self.y_score[half:])
stream_prec, stream_recall, stream_thresholds = (
new_curve.precision_recall_curve(class_index=class_idx)
)
np.testing.assert_almost_equal(stream_thresholds[1:], thresholds)
np.testing.assert_almost_equal(precision[:1], stream_prec[:1])
np.testing.assert_almost_equal(recall, stream_recall)

def test_roc_curve(self):
for class_idx in range(self.dim):
_fpr, _tpr, thresholds = sk_roc_curve(
self.y_test == class_idx, self.y_score[:, class_idx]
)

new_curve = StreamingMetrics(
thresholds=thresholds[1:],
num_classes=self.dim,
)

# ensure that multiple updates have the same effect as one big update
half = self.y_test.shape[0] // 2
new_curve.update(self.y_test[:half], self.y_score[:half])
new_curve.update(self.y_test[half:], self.y_score[half:])

streaming_fpr, streaming_tpr, _thr = new_curve.roc_curve(
class_index=class_idx
)
np.testing.assert_almost_equal(_fpr, streaming_fpr[:-1])
np.testing.assert_almost_equal(_tpr, streaming_tpr[:-1])


class TestStreamingMetricsBinary(unittest.TestCase):
def setUp(self):
iris = load_iris()
X, y = iris.data, iris.target
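
Finally, a condensed usage sketch distilled from the tests above. The data is synthetic and the import path is inferred from the module name in this diff, so treat it as a sketch rather than documented API usage:

import numpy as np
from streamauc.streaming_metrics import StreamingMetrics  # path inferred from the diff

rng = np.random.default_rng(0)
num_classes = 2
y_true = rng.integers(0, num_classes, size=200)
y_score = rng.random((200, num_classes))
y_score /= y_score.sum(-1, keepdims=True)

metrics = StreamingMetrics(
    thresholds=np.unique(y_score),
    num_classes=num_classes,
)

# stream the data in minibatches; counts accumulate across update() calls
for idx in np.array_split(np.arange(y_true.shape[0]), 4):
    metrics.update(y_true[idx], y_score[idx])

precision, recall, thr = metrics.precision_recall_curve(class_index=1)
fpr_vals, tpr_vals, _ = metrics.roc_curve(class_index=1)
confm = metrics.confusion_matrix   # per-threshold, per-class 2x2 counts
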
