-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmetrics.py
36 lines (32 loc) · 982 Bytes
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torch import Tensor
from torchmetrics.classification.stat_scores import BinaryStatScores
from torchmetrics.utilities.compute import _safe_divide
class BinaryExpectedCost(BinaryStatScores):
    """Expected (average) cost of binary classification outcomes.

    Every confusion-matrix outcome (true/false positive, true/false negative)
    is assigned a cost. ``compute`` returns the cost-weighted sum of the
    accumulated outcome counts divided by the total of those counts.
    """

    is_differentiable = False
    # This metric is a cost, so smaller values indicate a better model.
    higher_is_better = False
    full_state_update = False

    def __init__(
        self,
        ctp: float = 0.0,
        cfp: float = 1.0,
        cfn: float = 5.0,
        ctn: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Args:
            ctp: Cost of true positive
            cfp: Cost of false positive
            cfn: Cost of false negative
            ctn: Cost of true negative
            **kwargs: Additional keyword arguments forwarded unchanged to the
                ``BinaryStatScores`` constructor, so that base-class options
                can still be configured through this subclass. Which options
                exist depends on the installed torchmetrics version.
        """
        # Forward the extra options instead of silently fixing them to the
        # base-class defaults.
        super().__init__(**kwargs)
        self.ctp = ctp
        self.cfp = cfp
        self.cfn = cfn
        self.ctn = ctn

    def compute(self) -> Tensor:
        """Return the cost-weighted outcome counts divided by the total count."""
        # The inherited state is unpacked in the order tp, fp, tn, fn.
        tp, fp, tn, fn = self._final_state()
        weighted_cost = (
            self.ctp * tp + self.cfp * fp + self.cfn * fn + self.ctn * tn
        )
        total = tp + tn + fp + fn
        # NOTE(review): _safe_divide presumably guards against a zero
        # denominator (no updates yet) -- confirm against the installed
        # torchmetrics version.
        return _safe_divide(weighted_cost, total)