diff --git a/tensorflow_recommenders/layers/factorized_top_k.py b/tensorflow_recommenders/layers/factorized_top_k.py index b6a7f45..0316c3c 100644 --- a/tensorflow_recommenders/layers/factorized_top_k.py +++ b/tensorflow_recommenders/layers/factorized_top_k.py @@ -466,7 +466,7 @@ def top_k(state: Tuple[tf.Tensor, tf.Tensor], def enumerate_rows(batch: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: """Enumerates rows in each batch using a total element counter.""" - starting_counter = self._counter.read_value() + starting_counter = self._counter.value end_counter = self._counter.assign_add(tf.shape(batch)[0]) return tf.range(starting_counter, end_counter), batch