Merge branch 'experimental/sasrec' of https://github.com/MobileTeleSystems/RecTools into feature/sasrec_torch_rank
blondered committed Feb 11, 2025
2 parents af96437 + b12d1cc commit a785d57
Showing 6 changed files with 191 additions and 16 deletions.
10 changes: 5 additions & 5 deletions rectools/models/nn/item_net.py
@@ -106,12 +106,12 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:

def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
"""Get categorical item features and offsets for `items`."""
-        length_range = torch.arange(self.input_lengths.max().item(), device=self.device)
-        item_indexes = self.offsets[items].unsqueeze(-1) + length_range
-        length_mask = length_range < self.input_lengths[items].unsqueeze(-1)
-        item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]]
+        length_range = torch.arange(self.get_buffer("input_lengths").max().item(), device=self.device)
+        item_indexes = self.get_buffer("offsets")[items].unsqueeze(-1) + length_range
+        length_mask = length_range < self.get_buffer("input_lengths")[items].unsqueeze(-1)
+        item_emb_bag_inputs = self.get_buffer("emb_bag_inputs")[item_indexes[length_mask]]
item_offsets = torch.cat(
-            (torch.tensor([0], device=self.device), torch.cumsum(self.input_lengths[items], dim=0)[:-1])
+            (torch.tensor([0], device=self.device), torch.cumsum(self.get_buffer("input_lengths")[items], dim=0)[:-1])
)
return item_emb_bag_inputs, item_offsets

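Note on the change above: tensors registered with `register_buffer` travel with the module across devices and are included in its `state_dict`, and `Module.get_buffer(name)` retrieves them with an explicit `Tensor` return type, raising if the name is not a registered buffer (plain attribute access goes through `__getattr__`, which is typed as `Tensor | Module` and trips strict type checkers and scripting). A minimal standalone sketch of the pattern, with illustrative names rather than RecTools code:

import torch
from torch import nn

class BufferedLookup(nn.Module):
    """Toy module that keeps lookup tensors as registered buffers."""

    def __init__(self, offsets: torch.Tensor, input_lengths: torch.Tensor) -> None:
        super().__init__()
        # Buffers are saved in state_dict and follow .to(device) / .cuda() calls.
        self.register_buffer("offsets", offsets)
        self.register_buffer("input_lengths", input_lengths)

    def forward(self, items: torch.Tensor) -> torch.Tensor:
        # get_buffer returns a Tensor and fails loudly on an unknown name.
        return self.get_buffer("offsets")[items] + self.get_buffer("input_lengths")[items]

lookup = BufferedLookup(torch.tensor([0, 2, 4]), torch.tensor([2, 2, 1]))
print(lookup(torch.tensor([0, 2])))  # tensor([2, 5])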
7 changes: 6 additions & 1 deletion rectools/models/nn/transformer_lightning.py
@@ -202,11 +202,16 @@ def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) ->
outputs["loss"] = self._calc_gbce_loss(pos_neg_logits, y, w, negatives)
outputs["pos_neg_logits"] = pos_neg_logits.squeeze()
else:
raise ValueError(f"loss {self.loss} is not supported")
outputs = self._calc_custom_loss_outputs(batch, batch_idx) # pragma: no cover

self.log(self.val_loss_name, outputs["loss"], on_step=False, on_epoch=True, prog_bar=self.verbose > 0)
return outputs

+    def _calc_custom_loss_outputs(
+        self, batch: tp.Dict[str, torch.Tensor], batch_idx: int
+    ) -> tp.Dict[str, torch.Tensor]:
+        raise ValueError(f"loss {self.loss} is not supported")  # pragma: no cover

def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor:
item_embs, session_embs = self.torch_model(x)
logits = session_embs @ item_embs.T
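The refactor above turns an inline `raise` into an overridable hook: built-in losses stay in `validation_step`, and unknown losses are delegated to `_calc_custom_loss_outputs`, which subclasses can implement. A rough sketch of the pattern under simplified assumptions (placeholder loss math, not the actual RecTools module):

import typing as tp

import torch

class BaseModule:
    loss: str = "softmax"

    def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
        if self.loss == "softmax":
            return {"loss": torch.tensor(0.0)}  # built-in losses handled inline
        return self._calc_custom_loss_outputs(batch, batch_idx)  # extension point

    def _calc_custom_loss_outputs(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
        raise ValueError(f"loss {self.loss} is not supported")

class HingeModule(BaseModule):
    loss = "hinge"

    def _calc_custom_loss_outputs(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> tp.Dict[str, torch.Tensor]:
        # Placeholder custom loss computed from the batch.
        return {"loss": torch.clamp(1.0 - batch["logits"], min=0).mean()}

print(HingeModule().validation_step({"logits": torch.tensor([0.2, 2.0])}, 0))  # {'loss': tensor(0.4000)}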
48 changes: 48 additions & 0 deletions tests/models/nn/test_bert4rec.py
@@ -13,13 +13,15 @@
# limitations under the License.

import typing as tp
+from functools import partial

import numpy as np
import pandas as pd
import pytest
import torch
from pytorch_lightning import Trainer, seed_everything

+from rectools import ExternalIds
from rectools.columns import Columns
from rectools.dataset import Dataset
from rectools.models import BERT4RecModel
@@ -600,6 +602,30 @@ def data_preparator(self) -> BERT4RecDataPreparator:
mask_prob=0.5,
)

+    @pytest.fixture
+    def data_preparator_val_mask(self) -> BERT4RecDataPreparator:
+        def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
+            rank = (
+                interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
+                .groupby(Columns.User, sort=False)
+                .cumcount()
+                + 1
+            )
+            val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1)
+            return val_mask.values
+
+        val_users = [10, 30]
+        get_val_mask_func = partial(get_val_mask, val_users=val_users)
+        return BERT4RecDataPreparator(
+            session_max_len=4,
+            n_negatives=2,
+            train_min_user_interactions=2,
+            mask_prob=0.5,
+            batch_size=4,
+            dataloader_num_workers=0,
+            get_val_mask_func=get_val_mask_func,
+        )

@pytest.mark.parametrize(
"train_batch",
(
@@ -666,6 +692,28 @@ def test_get_dataloader_recommend(
for key, value in actual.items():
assert torch.equal(value, recommend_batch[key])

+    @pytest.mark.parametrize(
+        "val_batch",
+        (
+            (
+                {
+                    "x": torch.tensor([[0, 2, 4, 1]]),
+                    "y": torch.tensor([[3]]),
+                    "yw": torch.tensor([[1.0]]),
+                    "negatives": torch.tensor([[[5, 2]]]),
+                }
+            ),
+        ),
+    )
+    def test_get_dataloader_val(
+        self, dataset: Dataset, data_preparator_val_mask: BERT4RecDataPreparator, val_batch: tp.Dict[str, torch.Tensor]
+    ) -> None:
+        data_preparator_val_mask.process_dataset_train(dataset)
+        dataloader = data_preparator_val_mask.get_dataloader_val()
+        actual = next(iter(dataloader))  # type: ignore
+        for key, value in actual.items():
+            assert torch.equal(value, val_batch[key])


class TestBERT4RecModelConfiguration:
def setup_method(self) -> None:
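For readers unfamiliar with the masking idiom in `data_preparator_val_mask`: sorting by datetime descending with a stable sort and taking `groupby(...).cumcount() + 1` ranks each user's interactions from newest to oldest, so `rank <= 1` combined with the `val_users` filter holds out each validation user's latest interaction. A self-contained demo on made-up data (column names are stand-ins for the RecTools `Columns` constants):

import pandas as pd

interactions = pd.DataFrame(
    {
        "user_id": [10, 10, 30, 30, 30],
        "datetime": pd.to_datetime(["2021-01-01", "2021-01-03", "2021-01-01", "2021-01-02", "2021-01-05"]),
    }
)
rank = (
    interactions.sort_values("datetime", ascending=False, kind="stable")
    .groupby("user_id", sort=False)
    .cumcount()
    + 1
)
# Index alignment maps the ranks back onto the original row order.
val_mask = interactions["user_id"].isin([10, 30]) & (rank <= 1)
print(val_mask.values)  # [False  True False False  True] -> newest event per user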
95 changes: 92 additions & 3 deletions tests/models/nn/test_item_net.py
@@ -22,6 +22,7 @@

from rectools.columns import Columns
from rectools.dataset import Dataset
+from rectools.dataset.dataset import DatasetSchema, EntitySchema
from rectools.models.nn.item_net import (
CatFeaturesItemNet,
IdEmbeddingsItemNet,
@@ -101,6 +102,26 @@ def dataset_item_features(self) -> Dataset:
)
return ds

+    @pytest.fixture
+    def dataset_dense_item_features(self) -> Dataset:
+        item_features = pd.DataFrame(
+            [
+                [11, 1, 1],
+                [12, 1, 2],
+                [13, 1, 3],
+                [14, 2, 1],
+                [15, 2, 2],
+                [17, 2, 3],
+            ],
+            columns=[Columns.Item, "f1", "f2"],
+        )
+        ds = Dataset.construct(
+            INTERACTIONS,
+            item_features_df=item_features,
+            make_dense_item_features=True,
+        )
+        return ds

def test_get_item_inputs_offsets(self, dataset_item_features: Dataset) -> None:
items = torch.from_numpy(
dataset_item_features.item_id_map.convert_to_internal(INTERACTIONS[Columns.Item].unique())
@@ -123,11 +144,11 @@ def test_create_from_dataset(self, n_factors: int, dataset_item_features: Dataset

assert isinstance(cat_item_embeddings, CatFeaturesItemNet)

-        actual_offsets = cat_item_embeddings.offsets
+        actual_offsets = cat_item_embeddings.get_buffer("offsets")
actual_n_cat_feature_values = cat_item_embeddings.n_cat_feature_values
actual_embedding_dim = cat_item_embeddings.embedding_bag.embedding_dim
-        actual_emb_bag_inputs = cat_item_embeddings.emb_bag_inputs
-        actual_input_lengths = cat_item_embeddings.input_lengths
+        actual_emb_bag_inputs = cat_item_embeddings.get_buffer("emb_bag_inputs")
+        actual_input_lengths = cat_item_embeddings.get_buffer("input_lengths")

expected_offsets = torch.tensor([0, 0, 2, 4, 6, 8, 10])
expected_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2, 1, 3, 1, 3])
@@ -208,6 +229,74 @@ def test_when_cat_item_features_is_none(
cat_features_item_net = CatFeaturesItemNet.from_dataset(ds, n_factors=10, dropout_rate=0.5)
assert cat_features_item_net is None

+    def test_warns_when_dataset_schema_features_are_dense(self, dataset_dense_item_features: Dataset) -> None:
+        dataset_schema_dict = dataset_dense_item_features.get_schema()
+        item_schema = EntitySchema(
+            n_hot=dataset_schema_dict["items"]["n_hot"],
+            id_map=dataset_schema_dict["items"]["id_map"],
+            features=dataset_schema_dict["items"]["features"],
+        )
+        user_schema = EntitySchema(
+            n_hot=dataset_schema_dict["users"]["n_hot"],
+            id_map=dataset_schema_dict["users"]["id_map"],
+            features=dataset_schema_dict["users"]["features"],
+        )
+        dataset_schema = DatasetSchema(
+            n_interactions=dataset_schema_dict["n_interactions"],
+            users=user_schema,
+            items=item_schema,
+        )
+        with pytest.warns() as record:
+            CatFeaturesItemNet.from_dataset_schema(dataset_schema, n_factors=5, dropout_rate=0.5)
+        assert (
+            str(record[0].message)
+            == """
+            Ignoring `CatFeaturesItemNet` block because
+            dataset item features are dense and unable to contain categorical features.
+            """
+        )
+
+    def test_warns_when_dataset_schema_categorical_features_are_none(self) -> None:
+        item_features = pd.DataFrame(
+            [
+                [12, "f3", 1],
+                [13, "f3", 2],
+                [14, "f3", 3],
+                [15, "f3", 4],
+                [17, "f3", 5],
+                [16, "f3", 6],
+            ],
+            columns=["id", "feature", "value"],
+        )
+        dataset = Dataset.construct(
+            INTERACTIONS,
+            item_features_df=item_features,
+        )
+        dataset_schema_dict = dataset.get_schema()
+        item_schema = EntitySchema(
+            n_hot=dataset_schema_dict["items"]["n_hot"],
+            id_map=dataset_schema_dict["items"]["id_map"],
+            features=dataset_schema_dict["items"]["features"],
+        )
+        user_schema = EntitySchema(
+            n_hot=dataset_schema_dict["users"]["n_hot"],
+            id_map=dataset_schema_dict["users"]["id_map"],
+            features=dataset_schema_dict["users"]["features"],
+        )
+        dataset_schema = DatasetSchema(
+            n_interactions=dataset_schema_dict["n_interactions"],
+            users=user_schema,
+            items=item_schema,
+        )
+        with pytest.warns() as record:
+            CatFeaturesItemNet.from_dataset_schema(dataset_schema, n_factors=5, dropout_rate=0.5)
+        assert (
+            str(record[0].message)
+            == """
+            Ignoring `CatFeaturesItemNet` block because dataset item features do not contain categorical features.
+            """
+        )


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
class TestSumOfEmbeddingsConstructor:
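Background for the `expected_offsets` / `expected_emb_bag_inputs` assertions above: `CatFeaturesItemNet` flattens every item's categorical feature ids into one tensor plus per-item offsets, which is exactly the input layout `torch.nn.EmbeddingBag` uses to sum each item's feature embeddings in a single call. A minimal sketch with toy values (not the test's data):

import torch
from torch import nn

# Flat feature ids: item 0 -> [0, 2], item 1 -> [1], item 2 -> [3, 4].
emb_bag_inputs = torch.tensor([0, 2, 1, 3, 4])
offsets = torch.tensor([0, 2, 3])  # start position of each item's ids in the flat tensor

bag = nn.EmbeddingBag(num_embeddings=5, embedding_dim=4, mode="sum")
item_embeddings = bag(emb_bag_inputs, offsets)
print(item_embeddings.shape)  # torch.Size([3, 4]): one summed embedding per item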
2 changes: 2 additions & 0 deletions tests/models/nn/test_sasrec.py
@@ -771,6 +771,7 @@ def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
session_max_len=3,
batch_size=4,
dataloader_num_workers=0,
+            n_negatives=2,
get_val_mask_func=get_val_mask_func,
)

@@ -855,6 +856,7 @@ def test_get_dataloader_train(
"x": torch.tensor([[0, 1, 3]]),
"y": torch.tensor([[2]]),
"yw": torch.tensor([[1.0]]),
"negatives": torch.tensor([[[4, 1]]]),
}
),
),
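The new `negatives` entry mirrors `y`: its shape is `(batch_size, n_targets, n_negatives)`, so with `n_negatives=2` each target position gets two sampled negative item ids. A rough sketch of uniform sampling producing that shape (illustrative only; collision handling with positives is omitted, and the preparator's real sampling logic lives in RecTools):

import torch

batch_size, n_targets, n_negatives, n_items = 1, 1, 2, 6
negatives = torch.randint(low=1, high=n_items, size=(batch_size, n_targets, n_negatives))
print(negatives.shape)  # torch.Size([1, 1, 2]), e.g. tensor([[[4, 1]]])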
45 changes: 38 additions & 7 deletions tests/models/nn/test_transformer_base.py
@@ -19,6 +19,7 @@
import pandas as pd
import pytest
import torch
+from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger

@@ -103,7 +104,10 @@ def dataset_item_features(self) -> Dataset:
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize("default_trainer", (True, False))
def test_save_load_for_unfitted_model(
-        self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset, default_trainer: bool, trainer: Trainer
+        self,
+        model_cls: tp.Type[TransformerModelBase],
+        dataset: Dataset,
+        default_trainer: bool,
) -> None:
config = {
"deterministic": True,
@@ -142,7 +146,6 @@ def test_save_load_for_fitted_model(
model_cls: tp.Type[TransformerModelBase],
dataset_item_features: Dataset,
default_trainer: bool,
-        trainer: Trainer,
) -> None:
config = {
"deterministic": True,
@@ -154,22 +157,24 @@
model.fit(dataset_item_features)
assert_save_load_do_not_change_model(model, dataset_item_features)

@pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_load_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
tmp_path: str,
-        dataset_item_features: Dataset,
+        test_dataset: str,
+        request: FixtureRequest,
) -> None:

model = model_cls.from_config(
{
"deterministic": True,
"item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet
"item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet),
"get_trainer_func": custom_trainer_ckpt,
}
)
-        model.fit(dataset_item_features)
+        dataset = request.getfixturevalue(test_dataset)
+        model.fit(dataset)

assert model.fit_trainer is not None
if model.fit_trainer.log_dir is None:
Expand All @@ -179,7 +184,30 @@ def test_load_from_checkpoint(
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
assert isinstance(recovered_model, model_cls)

-        self._assert_same_reco(model, recovered_model, dataset_item_features)
+        self._assert_same_reco(model, recovered_model, dataset)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_save_model_loaded_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
) -> None:
model = model_cls.from_config(
{
"deterministic": True,
"item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet),
"get_trainer_func": custom_trainer_ckpt,
}
)
model.fit(dataset)
assert model.fit_trainer is not None
if model.fit_trainer.log_dir is None:
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
with pytest.raises(RuntimeError):
with NamedTemporaryFile() as f:
recovered_model.save(f.name)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize("verbose", (1, 0))
@@ -190,12 +218,14 @@ def test_load_from_checkpoint(
(True, ["epoch", "step", "train_loss", "val_loss"]),
),
)
@pytest.mark.parametrize("loss", ("softmax", "BCE", "gBCE"))
def test_log_metrics(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
tmp_path: str,
verbose: int,
+        loss: str,
is_val_mask_func: bool,
expected_columns: tp.List[str],
) -> None:
Expand All @@ -215,6 +245,7 @@ def test_log_metrics(
{
"verbose": verbose,
"get_val_mask_func": get_val_mask_func,
"loss": loss,
}
)
model._trainer = trainer # pylint: disable=protected-access
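A side note on the fixture plumbing introduced in `test_load_from_checkpoint`: parametrizing over fixture names as strings and resolving them with `request.getfixturevalue` runs one test body against several datasets without duplicating it. A generic, self-contained example of the pattern (names are illustrative):

import pytest
from pytest import FixtureRequest

@pytest.fixture
def small_list() -> list:
    return [1, 2]

@pytest.fixture
def big_list() -> list:
    return [1, 2, 3, 4]

@pytest.mark.parametrize("data_fixture", ("small_list", "big_list"))
def test_is_sorted(data_fixture: str, request: FixtureRequest) -> None:
    data = request.getfixturevalue(data_fixture)  # resolve the fixture by name
    assert data == sorted(data)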
