From 010fc726f38cd7f9a7fc7ba494f07965f20a6d9a Mon Sep 17 00:00:00 2001 From: Fabricio Arend Torres <9096900+FabricioArendTorres@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:03:57 +0200 Subject: [PATCH] feat: Implemented vectorized update of confusion matrices --- streamauc/streaming_metrics.py | 51 ++++++++------ tests/test_streaming_metrics.py | 119 +++++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 26 deletions(-) diff --git a/streamauc/streaming_metrics.py b/streamauc/streaming_metrics.py index 5f413f6..072d159 100644 --- a/streamauc/streaming_metrics.py +++ b/streamauc/streaming_metrics.py @@ -47,19 +47,19 @@ def _validate_thresholds( class StreamingMetrics: """ - Class for computing metrics in a minibatch-wise, iterative, fashion. + Class for keeping track of metrics for many thresholds in a + minibatch-wise, iterative, fashion. Parameters ---------- num_thresholds : int, optional Number of thresholds to evaluate the curve. Default is 200. - curve_type : str, optional - Type of curve to compute, either "ROC" or "PR". Default is "PR". num_classes : int Number of classes in the multiclass setting. Must be >= 2. thresholds : list of float, optional - List of specific thresholds to evaluate the curve. - + List of specific thresholds to evaluate the metrics at. + A probability >= threshold is defined as a positive prediction for + the respective class. """ def __init__( @@ -140,7 +140,7 @@ def update( If the shapes of `y_true` and `y_pred` do not match. """ - y_true = np.squeeze(y_true) + y_true = np.squeeze(y_true).astype(int) y_score = np.squeeze(y_score) if check_inputs: @@ -149,6 +149,10 @@ def update( f"Unknown shape of y_true: {y_true.shape}," f"must be squeezable to either [-1, num_classes] or [-1]." ) + if y_true.ndim==2 and np.any( y_true.sum(-1)!=1): + raise ValueError( + "The provided one-hot encoding is invalid." + ) if y_score.ndim > 2: raise ValueError( f"Unknown shape of y_true: {y_true.shape}," @@ -164,24 +168,25 @@ def update( raise ValueError(f"Invalid shape of y_pred: {y_score.shape}") if y_true.ndim == 2 and y_true.shape[1] == self.num_classes: - y_true_argmax = np.argmax(y_true, -1) + y_onehot = y_true else: - y_true_argmax = y_true - - for threshold_idx, threshold in enumerate(self.thresholds): - for class_idx in range(self.num_classes): - pred_pos = y_score[:, class_idx] >= threshold - is_pos = y_true_argmax == class_idx - - tp = np.sum(pred_pos & is_pos) - fp = np.sum(pred_pos & (~is_pos)) - fn = np.sum((~pred_pos) & (is_pos)) - tn = np.sum((~pred_pos) & (~is_pos)) - - self._confusion_matrix[threshold_idx, class_idx, 0, 0] += tp - self._confusion_matrix[threshold_idx, class_idx, 1, 0] += fp - self._confusion_matrix[threshold_idx, class_idx, 1, 1] += tn - self._confusion_matrix[threshold_idx, class_idx, 0, 1] += fn + y_onehot = np.eye(self.num_classes, dtype=int)[y_true] + + # use numpy broadcasting to get predictions + pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(-1,1,1) + is_pos = y_onehot[np.newaxis, ...] + + # sum over the minibatch samples + tp = np.sum(pred_pos & is_pos, 1) + fp = np.sum(pred_pos & (~is_pos), 1) + fn = np.sum((~pred_pos) & (is_pos), 1) + tn = np.sum((~pred_pos) & (~is_pos), 1) + + # update confusion matrix entry + self._confusion_matrix[..., 0, 0] += tp + self._confusion_matrix[..., 1, 0] += fp + self._confusion_matrix[..., 1, 1] += tn + self._confusion_matrix[..., 0, 1] += fn def _total(self) -> np.ndarray: """ diff --git a/tests/test_streaming_metrics.py b/tests/test_streaming_metrics.py index c266e0a..e97d9b0 100644 --- a/tests/test_streaming_metrics.py +++ b/tests/test_streaming_metrics.py @@ -4,7 +4,7 @@ from streamauc.metrics import f1_score, tpr, fpr import numpy as np -from sklearn.datasets import load_iris +from sklearn.datasets import load_iris, load_breast_cancer from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn import metrics @@ -83,7 +83,8 @@ def test_reset(self): curve.confusion_matrix, expected_empty_confm ) - y_true = np.random.randint(0, 2, (10, curve.num_classes)) + y_true = np.random.randint(0, 2, (10,)) + y_true = np.eye(curve.num_classes)[y_true] y_pred = np.random.random((10, curve.num_classes)) y_pred = y_pred / y_pred.sum(-1, keepdims=True) @@ -132,7 +133,8 @@ def test_invalid_input(self): curve.update(y_true=y_true, y_score=y_pred) # should not throw any errors - y_true = np.random.randint(0, 2, (10, curve.num_classes, 1, 1, 1, 1)) + y_true = np.random.randint(0, 2, (10,)) + y_true = np.eye(curve.num_classes)[y_true][..., np.newaxis, np.newaxis] y_pred = np.random.randint(0, 2, (10, curve.num_classes)) curve.update(y_true=y_true, y_score=y_pred) @@ -250,6 +252,117 @@ def test_sklearn(self): class TestStreamingMetrics(unittest.TestCase): + def setUp(self): + cancer_ds = load_breast_cancer() + X, y = cancer_ds.data, cancer_ds.target + + random_state = np.random.RandomState(0) + n_samples, n_features = X.shape + X = np.concatenate( + [X, random_state.randn(n_samples, 200 * n_features)], axis=1 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.5, stratify=y, random_state=0 + ) + + classifier = LogisticRegression(max_iter=1000) + self.y_score = classifier.fit(X_train, y_train).predict_proba(X_test) + + self.y_test = y_test + + thresholds = np.unique(self.y_score) + self.dim = 2 + self.curve = StreamingMetrics( + thresholds=thresholds, + num_classes=self.dim, + ) + + # check that multiple updates have the same effect as one big.. + half = self.y_test.shape[0] // 2 + self.curve.update(self.y_test[:half], self.y_score[:half]) + self.curve.update(self.y_test[half:], self.y_score[half:]) + + def test_total(self): + new_curve = StreamingMetrics( + num_thresholds=100, + num_classes=self.dim, + ) + + self.assertEqual(new_curve._total().shape, (100, self.dim)) + np.testing.assert_allclose( + new_curve._total(), np.zeros_like(new_curve._total()) + ) + + new_curve.update(self.y_test, self.y_score) + new_curve.update(self.y_test, self.y_score) + self.assertEqual(new_curve._total().shape, (100, self.dim)) + + np.testing.assert_allclose( + new_curve._total(), + 2 * self.y_test.shape[0] * np.ones_like(new_curve._total()), + ) + + def test_confusion_matrix(self): + for class_idx in range(self.dim): + y_true = self.y_test == class_idx + + for threshold in self.curve.thresholds: + y_pred = self.y_score[:, class_idx] >= threshold + + # sklearn has the confusion matrix flipped + confm_ref = np.flip(confusion_matrix(y_true, y_pred)) + + computed_confm = self.curve.confusion_matrix[ + self.curve.thresholds.tolist().index(threshold), class_idx + ] + np.testing.assert_array_equal(computed_confm, confm_ref) + + def test_precision_recall_curve(self): + for class_idx in range(self.dim): + precision, recall, thresholds = sk_precision_recall_curve( + self.y_test == class_idx, self.y_score[:, class_idx] + ) + + new_curve = StreamingMetrics( + thresholds=thresholds, + num_classes=self.dim, + ) + + # check that multiple updates have the same effect as one big.. + half = self.y_test.shape[0] // 2 + new_curve.update(self.y_test[:half], self.y_score[:half]) + new_curve.update(self.y_test[half:], self.y_score[half:]) + stream_prec, stream_recall, stream_thresholds = ( + new_curve.precision_recall_curve(class_index=class_idx) + ) + np.testing.assert_almost_equal(stream_thresholds[1:], thresholds) + np.testing.assert_almost_equal(precision[:1], stream_prec[:1]) + np.testing.assert_almost_equal(recall, stream_recall) + + def test_roc_curve(self): + for class_idx in range(self.dim): + _fpr, _tpr, thresholds = sk_roc_curve( + self.y_test == class_idx, self.y_score[:, class_idx] + ) + + new_curve = StreamingMetrics( + thresholds=thresholds[1:], + num_classes=self.dim, + ) + + # ensure that multiple updates have the same effect as one big.. + half = self.y_test.shape[0] // 2 + new_curve.update(self.y_test[:half], self.y_score[:half]) + new_curve.update(self.y_test[half:], self.y_score[half:]) + + streaming_fpr, streaming_tpr, _thr = new_curve.roc_curve( + class_index=class_idx + ) + np.testing.assert_almost_equal(_fpr, streaming_fpr[:-1]) + np.testing.assert_almost_equal(_tpr, streaming_tpr[:-1]) + + +class TestStreamingMetricsBinary(unittest.TestCase): def setUp(self): iris = load_iris() X, y = iris.data, iris.target