Commit fcd8662

add tf recommenders difference

ZhaoyueCheng committed May 3, 2024
1 parent 3aa525a commit fcd8662

Showing 4 changed files with 112 additions and 71 deletions.
29 changes: 20 additions & 9 deletions tensorflow_recommenders/experimental/models/ranking.py
@@ -23,6 +23,18 @@
from tensorflow_recommenders import tasks
from tensorflow_recommenders.layers import feature_interaction as feature_interaction_lib


+class MaskedAUC(tf.keras.metrics.AUC):
+  def __init__(self, padding_label=-1, **kwargs):
+    super().__init__(from_logits=True, **kwargs)
+    self.padding_label = padding_label
+
+  def update_state(self, y_true, y_pred, sample_weight=None):
+    # Drop entries whose label equals the padding value. Note that
+    # sample_weight is accepted for API compatibility but not applied.
+    mask = tf.not_equal(y_true, self.padding_label)
+
+    y_true_masked = tf.boolean_mask(y_true, mask)
+    y_pred_masked = tf.boolean_mask(y_pred, mask)
+
+    return super().update_state(y_true_masked, y_pred_masked)


class Ranking(models.Model):
  """A configurable ranking model.
@@ -113,10 +125,11 @@ def __init__(
    else:
      self._task = tasks.Ranking(
          loss=tf.keras.losses.BinaryCrossentropy(
-             reduction=tf.keras.losses.Reduction.NONE
+             reduction=tf.keras.losses.Reduction.NONE, from_logits=True
          ),
          metrics=[
-             tf.keras.metrics.AUC(name="auc"),
+             MaskedAUC(name="mauc"),
+             tf.keras.metrics.AUC(name="auc", from_logits=True),
              tf.keras.metrics.BinaryAccuracy(name="accuracy"),
          ],
          prediction_metrics=[
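
The switch to from_logits=True means the model is now expected to emit raw logits rather than sigmoid probabilities; the loss and both AUC metrics apply the sigmoid internally. A minimal illustration of the equivalence (a sketch, not from the commit):

    # Sketch: BinaryCrossentropy on logits equals BinaryCrossentropy on
    # sigmoided probabilities, up to floating-point error.
    import tensorflow as tf

    labels = tf.constant([[1.0], [0.0]])
    logits = tf.constant([[0.8], [-0.4]])

    bce_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    bce_probs = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    print(bce_logits(labels, logits).numpy())
    print(bce_probs(labels, tf.sigmoid(logits)).numpy())  # same value
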
@@ -130,18 +143,15 @@ def compute_loss(self,
  def compute_loss(self,
                   inputs: Union[
                       # Tuple of (features, labels).
                       Tuple[
                           Dict[str, tf.Tensor],
                           tf.Tensor
                       ],
                       # Tuple of (features, labels, sample weights).
                       Tuple[
                           Dict[str, tf.Tensor],
                           tf.Tensor,
                           Optional[tf.Tensor]
                       ]
                   ],
-                  training: bool = False) -> tf.Tensor:
+                  training: bool = False) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Computes the loss and metrics of the model.
Args:
@@ -191,14 +201,15 @@ def compute_loss(self,
          "or a tuple of (features, labels, sample weights). "
          f"Got a length {len(inputs)} tuple instead: {inputs}."
      )

+   features = inputs
+   labels = inputs['clicked']
    outputs = self(features, training=training)

-   loss = self._task(labels, outputs, sample_weight=sample_weight)
+   loss = self._task(labels, outputs, sample_weight=None)
    loss = tf.reduce_mean(loss)
    # Scales loss as the default gradients allreduce performs sum inside the
    # optimizer.
-   return loss / tf.distribute.get_strategy().num_replicas_in_sync
+   return loss / tf.distribute.get_strategy().num_replicas_in_sync, labels, outputs

def call(self, inputs: Dict[str, tf.Tensor]) -> tf.Tensor:
"""Executes forward and backward pass, returns loss.
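
compute_loss now returns (loss, labels, outputs) so the custom train and test steps in base.py (below) can update compiled metrics without a second forward pass. The division by num_replicas_in_sync compensates for the optimizer summing gradients across replicas; a sketch of the scaling, assuming the default sum-allreduce:

    # Sketch: per-replica loss scaling under a distribution strategy.
    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # default no-op strategy, 1 replica
    per_replica_loss = tf.constant(0.9)
    scaled = per_replica_loss / strategy.num_replicas_in_sync
    # With N replicas, the sum over replicas restores roughly the mean loss.
    print(scaled.numpy())
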
45 changes: 45 additions & 0 deletions tensorflow_recommenders/layers/blocks.py
@@ -44,6 +44,10 @@ def __init__(
    super().__init__(**kwargs)

    self._sublayers = []
+   self._units = units
+   self.use_bias = use_bias
+   self.activation = activation
+   self.final_activation = final_activation

    for num_units in units[:-1]:
      self._sublayers.append(
@@ -53,6 +57,47 @@
          tf.keras.layers.Dense(
              units[-1], activation=final_activation, use_bias=use_bias))
+
+  # def get_uniform_initializer(self, bottom_dim):
+  #   limit = tf.math.sqrt(1.0 / bottom_dim)
+  #   return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
+
+  def build(self, input_shape):
+    # The first layer's bottom_dim comes from the input shape.
+    bottom_dim = input_shape[1]
+    for num_units in self._units[:-1]:
+      self._sublayers.append(
+          tf.keras.layers.Dense(
+              num_units,
+              activation=self.activation,
+              use_bias=self.use_bias,
+              kernel_initializer=tf.keras.initializers.HeUniform(),
+              bias_initializer=tf.keras.initializers.RandomUniform(
+                  minval=-tf.math.sqrt(1.0 / bottom_dim),
+                  maxval=tf.math.sqrt(1.0 / bottom_dim),
+                  seed=0
+              ),
+          )
+      )
+      bottom_dim = num_units  # Update bottom_dim for the next layer.
+
+    # Add the final layer.
+    self._sublayers.append(
+        tf.keras.layers.Dense(
+            self._units[-1],
+            activation=self.final_activation,
+            use_bias=self.use_bias,
+            kernel_initializer=tf.keras.initializers.HeUniform(),
+            bias_initializer=tf.keras.initializers.RandomUniform(
+                minval=-tf.math.sqrt(1.0 / bottom_dim),
+                maxval=tf.math.sqrt(1.0 / bottom_dim),
+                seed=0
+            ),
+        )
+    )
+    super().build(input_shape)

  def call(self, x: tf.Tensor) -> tf.Tensor:
    """Performs the forward computation of the block."""
    for layer in self._sublayers:
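
The new build() mirrors PyTorch's nn.Linear defaults: He-uniform kernels and biases drawn from U(-sqrt(1/fan_in), sqrt(1/fan_in)). Note that __init__ above still appends its own Dense sublayers and this file has no deletions, so as written build() appends a second stack on top of them. The initializer ranges, written out (a sketch, with fan_in = 64 as an arbitrary example):

    # Sketch: the initializers build() configures, for a layer with fan_in = 64.
    import tensorflow as tf

    fan_in = 64
    limit = tf.math.sqrt(1.0 / fan_in)  # 0.125
    bias_init = tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    kernel_init = tf.keras.initializers.HeUniform()  # U(+/- sqrt(6 / fan_in))
    print(bias_init(shape=(4,)).numpy())  # samples in [-0.125, 0.125]
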
@@ -87,14 +87,19 @@ def __init__(
      num_layers: Optional[int] = 3,
      use_bias: bool = True,
-     kernel_initializer: Union[
-         Text, tf.keras.initializers.Initializer] = "truncated_normal",
-     bias_initializer: Union[Text,
-                             tf.keras.initializers.Initializer] = "zeros",
-     kernel_regularizer: Union[Text, None,
-                               tf.keras.regularizers.Regularizer] = None,
-     bias_regularizer: Union[Text, None,
-                             tf.keras.regularizers.Regularizer] = None,
-     **kwargs):
+     kernel_initializer: Union[
+         Text, tf.keras.initializers.Initializer
+     ] = "he_uniform",
+     bias_initializer: Union[
+         Text, tf.keras.initializers.Initializer
+     ] = "zeros",
+     kernel_regularizer: Union[
+         Text, None, tf.keras.regularizers.Regularizer
+     ] = None,
+     bias_regularizer: Union[
+         Text, None, tf.keras.regularizers.Regularizer
+     ] = None,
+     **kwargs
+  ):

    super(MultiLayerDCN, self).__init__(**kwargs)

@@ -113,23 +118,26 @@ def build(self, input_shape):
    last_dim = input_shape[-1]
    self._dense_u_kernels, self._dense_v_kernels = [], []

-   for _ in range(self._num_layers):
-     self._dense_u_kernels.append(tf.keras.layers.Dense(
-         self._projection_dim,
-         kernel_initializer=_clone_initializer(self._kernel_initializer),
-         kernel_regularizer=self._kernel_regularizer,
-         use_bias=False,
-         dtype=self.dtype,
-     ))
-     self._dense_v_kernels.append(tf.keras.layers.Dense(
-         last_dim,
-         kernel_initializer=_clone_initializer(self._kernel_initializer),
-         bias_initializer=self._bias_initializer,
-         kernel_regularizer=self._kernel_regularizer,
-         bias_regularizer=self._bias_regularizer,
-         use_bias=self._use_bias,
-         dtype=self.dtype,
-     ))
+   for _ in range(self._num_layers):
+     self._dense_u_kernels.append(
+         tf.keras.layers.Dense(
+             self._projection_dim,
+             kernel_initializer='glorot_normal',
+             kernel_regularizer=self._kernel_regularizer,
+             use_bias=False,
+         )
+     )
+     self._dense_v_kernels.append(
+         tf.keras.layers.Dense(
+             last_dim,
+             kernel_initializer='glorot_normal',
+             bias_initializer='zeros',
+             kernel_regularizer=self._kernel_regularizer,
+             bias_regularizer=self._bias_regularizer,
+             use_bias=True,
+         )
+     )

    self.built = True
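
Each MultiLayerDCN layer pairs a projection_dim-wide "U" kernel with a full-width "V" kernel, the low-rank factorization from DCN-v2. Assuming the usual DCN-v2 update x_{l+1} = x_0 * V(U(x_l)) + x_l (the call() method is not shown in this diff), one layer looks like:

    # Sketch of one low-rank cross layer; shapes and names are illustrative.
    import tensorflow as tf

    batch, dim, projection_dim = 8, 32, 4
    x0 = tf.random.normal([batch, dim])  # original input features
    xl = tf.random.normal([batch, dim])  # output of the previous cross layer

    u = tf.keras.layers.Dense(projection_dim, use_bias=False)  # "U" kernel
    v = tf.keras.layers.Dense(dim)                             # "V" kernel

    x_next = x0 * v(u(xl)) + xl  # elementwise interaction plus residual
    print(x_next.shape)          # (8, 32)
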
51 changes: 14 additions & 37 deletions tensorflow_recommenders/models/base.py
@@ -1,22 +1,7 @@
-# Copyright 2024 The TensorFlow Recommenders Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# lint-as: python3
"""Base model."""

import tensorflow as tf

+import numpy as np

class Model(tf.keras.Model):
"""Base model for TFRS models.
@@ -65,40 +50,32 @@ def train_step(self, inputs):
"""Custom train step using the `compute_loss` method."""

with tf.GradientTape() as tape:
loss = self.compute_loss(inputs, training=True)

# Handle regularization losses as well.
regularization_loss = tf.reduce_sum(
[tf.reduce_sum(loss) for loss in self.losses]
)
loss, labels, outputs = self.compute_loss(inputs, training=True)

total_loss = loss + regularization_loss
total_loss = loss # + regularization_loss

gradients = tape.gradient(total_loss, self.trainable_variables)

self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

metrics = {metric.name: metric.result() for metric in self.metrics}
self.compiled_metrics.update_state(labels, outputs)
metrics = self.get_metrics_result()
metrics["loss"] = loss
metrics["regularization_loss"] = regularization_loss
metrics["regularization_loss"] = 0 # regularization_loss
metrics["total_loss"] = total_loss

return metrics

  def test_step(self, inputs):
    """Custom test step using the `compute_loss` method."""

-   loss = self.compute_loss(inputs, training=False)
-
-   # Handle regularization losses as well.
-   regularization_loss = tf.reduce_sum(
-       [tf.reduce_sum(loss) for loss in self.losses]
-   )
-
-   total_loss = loss + regularization_loss
-
-   metrics = {metric.name: metric.result() for metric in self.metrics}
+   loss, labels, outputs = self.compute_loss(inputs, training=False)
+
+   total_loss = loss  # + regularization_loss
+   self.compiled_metrics.update_state(labels, outputs)
+   metrics = self.get_metrics_result()
    metrics["loss"] = loss
-   metrics["regularization_loss"] = regularization_loss
+   metrics["regularization_loss"] = 0  # regularization_loss
    metrics["total_loss"] = total_loss

-   return metrics
+   return metrics  # , labels, outputs
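
With compute_loss returning (loss, labels, outputs), both steps now route metric updates through compiled_metrics and report regularization_loss as a constant 0, since the self.losses sum is commented out. The compiled-metrics pattern in isolation (a sketch, assuming TF 2.x with Keras 2, where compiled_metrics is a MetricsContainer; TinyModel is a placeholder, not part of TFRS):

    # Sketch: updating compiled metrics outside of fit().
    import tensorflow as tf

    class TinyModel(tf.keras.Model):
      def call(self, x):
        return x

    model = TinyModel()
    model.compile(metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")])
    labels = tf.constant([[1.0], [0.0]])
    outputs = tf.constant([[0.9], [0.2]])
    model.compiled_metrics.update_state(labels, outputs)
    print(model.get_metrics_result())  # {'accuracy': <tf.Tensor: ... 1.0>}
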
