Skip to content

Commit

Permalink
chore: fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
FabricioArendTorres committed Jul 19, 2024
1 parent 6761493 commit 48392d5
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions streamauc/streaming_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,8 @@ def update(
f"Unknown shape of y_true: {y_true.shape},"
f"must be squeezable to either [-1, num_classes] or [-1]."
)
if y_true.ndim==2 and np.any( y_true.sum(-1)!=1):
raise ValueError(
"The provided one-hot encoding is invalid."
)
if y_true.ndim == 2 and np.any(y_true.sum(-1) != 1):
raise ValueError("The provided one-hot encoding is invalid.")
if y_score.ndim > 2:
raise ValueError(
f"Unknown shape of y_true: {y_true.shape},"
Expand All @@ -173,7 +171,9 @@ def update(
y_onehot = np.eye(self.num_classes, dtype=int)[y_true]

# use numpy broadcasting to get predictions
pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(-1,1,1)
pred_pos = y_score[np.newaxis, ...] >= self.thresholds.reshape(
-1, 1, 1
)
is_pos = y_onehot[np.newaxis, ...]

# sum over the minibatch samples
Expand Down

0 comments on commit 48392d5

Please sign in to comment.