@@ -76,7 +76,22 @@ def __init__(
76
76
self .preds_max = preds_max
77
77
self .actual_cutpoints = actual_cutpoints
78
78
self .preds_cutpoints = preds_cutpoints
79
- self .reset_state ()
79
+ self .actual_cuts = tf .linspace (
80
+ tf .cast (self .actual_min , tf .float32 ),
81
+ tf .cast (self .actual_max , tf .float32 ),
82
+ self .actual_cutpoints - 1 ,
83
+ )
84
+ self .preds_cuts = tf .linspace (
85
+ tf .cast (self .preds_min , tf .float32 ),
86
+ tf .cast (self .preds_max , tf .float32 ),
87
+ self .preds_cutpoints - 1 ,
88
+ )
89
+ self .m = self .add_weight (
90
+ "m" , (self .actual_cutpoints , self .preds_cutpoints ), dtype = tf .int64
91
+ )
92
+ self .nrow = self .add_weight ("nrow" , (self .actual_cutpoints ), dtype = tf .int64 )
93
+ self .ncol = self .add_weight ("ncol" , (self .preds_cutpoints ), dtype = tf .int64 )
94
+ self .n = self .add_weight ("n" , (), dtype = tf .int64 )
80
95
81
96
def update_state (self , y_true , y_pred , sample_weight = None ):
82
97
"""Accumulates ranks.
@@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None):
89
104
Returns:
90
105
Update op.
91
106
"""
92
- if y_true .shape and y_true .shape [0 ]:
93
- i = tf .searchsorted (
94
- self .actual_cuts ,
95
- tf .cast (tf .reshape (y_true , - 1 ), self .actual_cuts .dtype ),
107
+ i = tf .searchsorted (
108
+ self .actual_cuts ,
109
+ tf .cast (tf .reshape (y_true , [- 1 ]), self .actual_cuts .dtype ),
110
+ )
111
+ j = tf .searchsorted (
112
+ self .preds_cuts , tf .cast (tf .reshape (y_pred , [- 1 ]), self .preds_cuts .dtype )
113
+ )
114
+
115
+ m = tf .sparse .from_dense (self .m )
116
+ nrow = tf .sparse .from_dense (self .nrow )
117
+ ncol = tf .sparse .from_dense (self .ncol )
118
+
119
+ k = 0
120
+ while k < tf .shape (i )[0 ]:
121
+ m = tf .sparse .add (
122
+ m ,
123
+ tf .SparseTensor (
124
+ [[i [k ], j [k ]]],
125
+ tf .cast ([1 ], dtype = m .dtype ),
126
+ self .m .shape ,
127
+ ),
96
128
)
97
- j = tf .searchsorted (
98
- self .preds_cuts , tf .cast (tf .reshape (y_pred , - 1 ), self .preds_cuts .dtype )
129
+ nrow = tf .sparse .add (
130
+ nrow ,
131
+ tf .SparseTensor (
132
+ [[i [k ]]],
133
+ tf .cast ([1 ], dtype = nrow .dtype ),
134
+ self .nrow .shape ,
135
+ ),
99
136
)
100
-
101
- def body (k , n , m , nrow , ncol ):
102
- return (
103
- k + 1 ,
104
- n + 1 ,
105
- tf .sparse .add (
106
- m ,
107
- tf .SparseTensor (
108
- [[i [k ], j [k ]]],
109
- tf .cast ([1 ], dtype = self .m .dtype ),
110
- self .m .shape ,
111
- ),
112
- ),
113
- tf .sparse .add (
114
- nrow ,
115
- tf .SparseTensor (
116
- [[i [k ]]],
117
- tf .cast ([1 ], dtype = self .nrow .dtype ),
118
- self .nrow .shape ,
119
- ),
120
- ),
121
- tf .sparse .add (
122
- ncol ,
123
- tf .SparseTensor (
124
- [[j [k ]]],
125
- tf .cast ([1 ], dtype = self .ncol .dtype ),
126
- self .ncol .shape ,
127
- ),
128
- ),
129
- )
130
-
131
- _ , self .n , self .m , self .nrow , self .ncol = tf .while_loop (
132
- lambda k , n , m , nrow , ncol : k < i .shape [0 ],
133
- body = body ,
134
- loop_vars = (0 , self .n , self .m , self .nrow , self .ncol ),
137
+ ncol = tf .sparse .add (
138
+ ncol ,
139
+ tf .SparseTensor (
140
+ [[j [k ]]],
141
+ tf .cast ([1 ], dtype = ncol .dtype ),
142
+ self .ncol .shape ,
143
+ ),
135
144
)
145
+ k += 1
146
+
147
+ self .n .assign_add (tf .cast (k , tf .int64 ))
148
+ self .m .assign (tf .sparse .to_dense (m ))
149
+ self .nrow .assign (tf .sparse .to_dense (nrow ))
150
+ self .ncol .assign (tf .sparse .to_dense (ncol ))
136
151
137
152
def result (self ):
138
- m_dense = tf .sparse .to_dense (tf .cast (self .m , tf .float32 ))
139
- n_cap = tf .cumsum (
140
- tf .cumsum (
141
- tf .slice (tf .pad (m_dense , [[1 , 0 ], [1 , 0 ]]), [0 , 0 ], self .m .shape ),
142
- axis = 0 ,
143
- ),
144
- axis = 1 ,
145
- )
153
+ m = tf .cast (self .m , tf .float32 )
154
+ n_cap = tf .cumsum (tf .cumsum (m , axis = 0 ), axis = 1 )
146
155
# Number of concordant pairs.
147
- p = tf .math .reduce_sum (tf .multiply (n_cap , m_dense ))
148
- sum_m_squard = tf .math .reduce_sum (tf .math .square (m_dense ))
156
+ p = tf .math .reduce_sum (tf .multiply (n_cap [: - 1 , : - 1 ], m [ 1 :, 1 :] ))
157
+ sum_m_squard = tf .math .reduce_sum (tf .math .square (m ))
149
158
# Ties in x.
150
159
t = (
151
- tf .math .reduce_sum (tf .math .square (tf . sparse . to_dense ( self .nrow )))
160
+ tf .cast ( tf . math .reduce_sum (tf .math .square (self .nrow )), tf . float32 )
152
161
- sum_m_squard
153
162
) / 2.0
154
163
# Ties in y.
155
164
u = (
156
- tf .math .reduce_sum (tf .math .square (tf . sparse . to_dense ( self .ncol )))
165
+ tf .cast ( tf . math .reduce_sum (tf .math .square (self .ncol )), tf . float32 )
157
166
- sum_m_squard
158
167
) / 2.0
159
168
# Ties in both.
160
- b = tf .math .reduce_sum (tf .multiply (m_dense , (m_dense - 1.0 ))) / 2.0
169
+ b = tf .math .reduce_sum (tf .multiply (m , (m - 1.0 ))) / 2.0
161
170
# Number of discordant pairs.
162
171
n = tf .cast (self .n , tf .float32 )
163
172
q = (n - 1.0 ) * n / 2.0 - p - t - u - b
@@ -179,28 +188,11 @@ def get_config(self):
179
188
180
189
def reset_state (self ):
181
190
"""Resets all of the metric state variables."""
182
- self .actual_cuts = tf .linspace (
183
- tf .cast (self .actual_min , tf .float32 ),
184
- tf .cast (self .actual_max , tf .float32 ),
185
- self .actual_cutpoints - 1 ,
186
- )
187
- self .preds_cuts = tf .linspace (
188
- tf .cast (self .preds_min , tf .float32 ),
189
- tf .cast (self .preds_max , tf .float32 ),
190
- self .preds_cutpoints - 1 ,
191
- )
192
- self .m = tf .SparseTensor (
193
- tf .zeros ((0 , 2 ), tf .int64 ),
194
- [],
195
- [self .actual_cutpoints , self .preds_cutpoints ],
196
- )
197
- self .nrow = tf .SparseTensor (
198
- tf .zeros ((0 , 1 ), dtype = tf .int64 ), [], [self .actual_cutpoints ]
199
- )
200
- self .ncol = tf .SparseTensor (
201
- tf .zeros ((0 , 1 ), dtype = tf .int64 ), [], [self .preds_cutpoints ]
202
- )
203
- self .n = 0
191
+
192
+ self .m .assign (tf .zeros ((self .actual_cutpoints , self .preds_cutpoints ), tf .int64 ))
193
+ self .nrow .assign (tf .zeros ((self .actual_cutpoints ), tf .int64 ))
194
+ self .ncol .assign (tf .zeros ((self .preds_cutpoints ), tf .int64 ))
195
+ self .n .assign (0 )
204
196
205
197
def reset_states (self ):
206
198
# Backwards compatibility alias of `reset_state`. New classes should
0 commit comments