diff --git a/streamauc/streaming_metrics.py b/streamauc/streaming_metrics.py index 072d159..6f6113e 100644 --- a/streamauc/streaming_metrics.py +++ b/streamauc/streaming_metrics.py @@ -149,10 +149,8 @@ def update( f"Unknown shape of y_true: {y_true.shape}," f"must be squeezable to either [-1, num_classes] or [-1]." ) - if y_true.ndim==2 and np.any( y_true.sum(-1)!=1): - raise ValueError( - "The provided one-hot encoding is invalid." - ) + if y_true.ndim == 2 and np.any(y_true.sum(-1) != 1): + raise ValueError("The provided one-hot encoding is invalid.") if y_score.ndim > 2: raise ValueError( f"Unknown shape of y_true: {y_true.shape}," @@ -173,7 +171,9 @@ def update( y_onehot = np.eye(self.num_classes, dtype=int)[y_true] # use numpy broadcasting to get predictions - pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(-1,1,1) + pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape( + -1, 1, 1 + ) is_pos = y_onehot[np.newaxis, ...] # sum over the minibatch samples