From 48392d524ab52db28fd28e993ddcca410e3409ea Mon Sep 17 00:00:00 2001 From: Fabricio Arend Torres <9096900+FabricioArendTorres@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:06:49 +0200 Subject: [PATCH] chore: fmt --- streamauc/streaming_metrics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/streamauc/streaming_metrics.py b/streamauc/streaming_metrics.py index 072d159..6f6113e 100644 --- a/streamauc/streaming_metrics.py +++ b/streamauc/streaming_metrics.py @@ -149,10 +149,8 @@ def update( f"Unknown shape of y_true: {y_true.shape}," f"must be squeezable to either [-1, num_classes] or [-1]." ) - if y_true.ndim==2 and np.any( y_true.sum(-1)!=1): - raise ValueError( - "The provided one-hot encoding is invalid." - ) + if y_true.ndim == 2 and np.any(y_true.sum(-1) != 1): + raise ValueError("The provided one-hot encoding is invalid.") if y_score.ndim > 2: raise ValueError( f"Unknown shape of y_true: {y_true.shape}," @@ -173,7 +171,9 @@ def update( y_onehot = np.eye(self.num_classes, dtype=int)[y_true] # use numpy broadcasting to get predictions - pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(-1,1,1) + pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape( + -1, 1, 1 + ) is_pos = y_onehot[np.newaxis, ...] # sum over the minibatch samples