Skip to content

Commit c7c40a0

Browse files
authored
Fix the Kendalls Tau metric when used in graph mode (#2739)
* - Fix the metric when used in graph mode * - Removed unnecessary padding op * - import keras -> import tensorflow.keras
1 parent b214e25 commit c7c40a0

File tree

2 files changed

+71
-78
lines changed

2 files changed

+71
-78
lines changed

tensorflow_addons/metrics/kendalls_tau.py

+69-77
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,22 @@ def __init__(
7676
self.preds_max = preds_max
7777
self.actual_cutpoints = actual_cutpoints
7878
self.preds_cutpoints = preds_cutpoints
79-
self.reset_state()
79+
self.actual_cuts = tf.linspace(
80+
tf.cast(self.actual_min, tf.float32),
81+
tf.cast(self.actual_max, tf.float32),
82+
self.actual_cutpoints - 1,
83+
)
84+
self.preds_cuts = tf.linspace(
85+
tf.cast(self.preds_min, tf.float32),
86+
tf.cast(self.preds_max, tf.float32),
87+
self.preds_cutpoints - 1,
88+
)
89+
self.m = self.add_weight(
90+
"m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64
91+
)
92+
self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64)
93+
self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64)
94+
self.n = self.add_weight("n", (), dtype=tf.int64)
8095

8196
def update_state(self, y_true, y_pred, sample_weight=None):
8297
"""Accumulates ranks.
@@ -89,75 +104,69 @@ def update_state(self, y_true, y_pred, sample_weight=None):
89104
Returns:
90105
Update op.
91106
"""
92-
if y_true.shape and y_true.shape[0]:
93-
i = tf.searchsorted(
94-
self.actual_cuts,
95-
tf.cast(tf.reshape(y_true, -1), self.actual_cuts.dtype),
107+
i = tf.searchsorted(
108+
self.actual_cuts,
109+
tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype),
110+
)
111+
j = tf.searchsorted(
112+
self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype)
113+
)
114+
115+
m = tf.sparse.from_dense(self.m)
116+
nrow = tf.sparse.from_dense(self.nrow)
117+
ncol = tf.sparse.from_dense(self.ncol)
118+
119+
k = 0
120+
while k < tf.shape(i)[0]:
121+
m = tf.sparse.add(
122+
m,
123+
tf.SparseTensor(
124+
[[i[k], j[k]]],
125+
tf.cast([1], dtype=m.dtype),
126+
self.m.shape,
127+
),
96128
)
97-
j = tf.searchsorted(
98-
self.preds_cuts, tf.cast(tf.reshape(y_pred, -1), self.preds_cuts.dtype)
129+
nrow = tf.sparse.add(
130+
nrow,
131+
tf.SparseTensor(
132+
[[i[k]]],
133+
tf.cast([1], dtype=nrow.dtype),
134+
self.nrow.shape,
135+
),
99136
)
100-
101-
def body(k, n, m, nrow, ncol):
102-
return (
103-
k + 1,
104-
n + 1,
105-
tf.sparse.add(
106-
m,
107-
tf.SparseTensor(
108-
[[i[k], j[k]]],
109-
tf.cast([1], dtype=self.m.dtype),
110-
self.m.shape,
111-
),
112-
),
113-
tf.sparse.add(
114-
nrow,
115-
tf.SparseTensor(
116-
[[i[k]]],
117-
tf.cast([1], dtype=self.nrow.dtype),
118-
self.nrow.shape,
119-
),
120-
),
121-
tf.sparse.add(
122-
ncol,
123-
tf.SparseTensor(
124-
[[j[k]]],
125-
tf.cast([1], dtype=self.ncol.dtype),
126-
self.ncol.shape,
127-
),
128-
),
129-
)
130-
131-
_, self.n, self.m, self.nrow, self.ncol = tf.while_loop(
132-
lambda k, n, m, nrow, ncol: k < i.shape[0],
133-
body=body,
134-
loop_vars=(0, self.n, self.m, self.nrow, self.ncol),
137+
ncol = tf.sparse.add(
138+
ncol,
139+
tf.SparseTensor(
140+
[[j[k]]],
141+
tf.cast([1], dtype=ncol.dtype),
142+
self.ncol.shape,
143+
),
135144
)
145+
k += 1
146+
147+
self.n.assign_add(tf.cast(k, tf.int64))
148+
self.m.assign(tf.sparse.to_dense(m))
149+
self.nrow.assign(tf.sparse.to_dense(nrow))
150+
self.ncol.assign(tf.sparse.to_dense(ncol))
136151

137152
def result(self):
138-
m_dense = tf.sparse.to_dense(tf.cast(self.m, tf.float32))
139-
n_cap = tf.cumsum(
140-
tf.cumsum(
141-
tf.slice(tf.pad(m_dense, [[1, 0], [1, 0]]), [0, 0], self.m.shape),
142-
axis=0,
143-
),
144-
axis=1,
145-
)
153+
m = tf.cast(self.m, tf.float32)
154+
n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1)
146155
# Number of concordant pairs.
147-
p = tf.math.reduce_sum(tf.multiply(n_cap, m_dense))
148-
sum_m_squard = tf.math.reduce_sum(tf.math.square(m_dense))
156+
p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:]))
157+
sum_m_squard = tf.math.reduce_sum(tf.math.square(m))
149158
# Ties in x.
150159
t = (
151-
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.nrow)))
160+
tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32)
152161
- sum_m_squard
153162
) / 2.0
154163
# Ties in y.
155164
u = (
156-
tf.math.reduce_sum(tf.math.square(tf.sparse.to_dense(self.ncol)))
165+
tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32)
157166
- sum_m_squard
158167
) / 2.0
159168
# Ties in both.
160-
b = tf.math.reduce_sum(tf.multiply(m_dense, (m_dense - 1.0))) / 2.0
169+
b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0
161170
# Number of discordant pairs.
162171
n = tf.cast(self.n, tf.float32)
163172
q = (n - 1.0) * n / 2.0 - p - t - u - b
@@ -179,28 +188,11 @@ def get_config(self):
179188

180189
def reset_state(self):
181190
"""Resets all of the metric state variables."""
182-
self.actual_cuts = tf.linspace(
183-
tf.cast(self.actual_min, tf.float32),
184-
tf.cast(self.actual_max, tf.float32),
185-
self.actual_cutpoints - 1,
186-
)
187-
self.preds_cuts = tf.linspace(
188-
tf.cast(self.preds_min, tf.float32),
189-
tf.cast(self.preds_max, tf.float32),
190-
self.preds_cutpoints - 1,
191-
)
192-
self.m = tf.SparseTensor(
193-
tf.zeros((0, 2), tf.int64),
194-
[],
195-
[self.actual_cutpoints, self.preds_cutpoints],
196-
)
197-
self.nrow = tf.SparseTensor(
198-
tf.zeros((0, 1), dtype=tf.int64), [], [self.actual_cutpoints]
199-
)
200-
self.ncol = tf.SparseTensor(
201-
tf.zeros((0, 1), dtype=tf.int64), [], [self.preds_cutpoints]
202-
)
203-
self.n = 0
191+
192+
self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64))
193+
self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64))
194+
self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64))
195+
self.n.assign(0)
204196

205197
def reset_states(self):
206198
# Backwards compatibility alias of `reset_state`. New classes should

tensorflow_addons/metrics/tests/kendalls_tau_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def test_keras_binary_classification_model():
9090
x = np.random.rand(1000, 10).astype(np.float32)
9191
y = np.random.rand(1000, 1).astype(np.float32)
9292

93-
model.fit(x, y, epochs=1, verbose=0, batch_size=32)
93+
history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
94+
assert not any(np.isnan(history.history["kendalls_tau"]))
9495

9596

9697
def test_kendalls_tau_serialization():

0 commit comments

Comments
 (0)