From 6affea43cb5e5d138a48958e72a222687b7f066d Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 22 Jan 2025 13:18:12 +0300 Subject: [PATCH 01/25] transformer base and sasrec config --- rectools/models/nn/sasrec.py | 101 +++++++++-- rectools/models/nn/transformer_base.py | 234 ++++++++++++++++++++----- tests/models/nn/test_sasrec.py | 206 +++++++++++++++++++++- 3 files changed, 484 insertions(+), 57 deletions(-) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 28543f20..5b7a5aad 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -13,20 +13,23 @@ # limitations under the License. import typing as tp -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple +import typing_extensions as tpe import numpy as np import torch from pytorch_lightning import Trainer -from pytorch_lightning.accelerators import Accelerator from torch import nn from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase from .transformer_base import ( PADDING_VALUE, + SessionEncoderDataPreparatorType, SessionEncoderLightningModule, SessionEncoderLightningModuleBase, + TransformerLayersType, TransformerModelBase, + TransformerModelConfig, ) from .transformer_data_preparator import SessionEncoderDataPreparatorBase from .transformer_net_blocks import ( @@ -157,7 +160,14 @@ def forward( return seqs -class SASRecModel(TransformerModelBase): +class SASRecModelConfig(TransformerModelConfig): + """SASRecModel config.""" + + data_preparator_type: SessionEncoderDataPreparatorType = SASRecDataPreparator + transformer_layers_type: TransformerLayersType = SASRecTransformerLayers + + +class SASRecModel(TransformerModelBase[SASRecModelConfig]): """ SASRec model. @@ -231,6 +241,8 @@ class SASRecModel(TransformerModelBase): Type of lightning module defining training procedure. """ + config_class = SASRecModelConfig + def __init__( # pylint: disable=too-many-arguments, too-many-locals self, n_blocks: int = 1, @@ -250,7 +262,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs: int = 3, verbose: int = 0, deterministic: bool = False, - recommend_device: Union[str, Accelerator] = "auto", + recommend_device: str = "auto", recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, train_min_user_interactions: int = 2, @@ -272,7 +284,10 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals use_key_padding_mask=use_key_padding_mask, dropout_rate=dropout_rate, session_max_len=session_max_len, + dataloader_num_workers=dataloader_num_workers, + batch_size=batch_size, loss=loss, + n_negatives=n_negatives, gbce_t=gbce_t, lr=lr, epochs=epochs, @@ -281,16 +296,82 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_device=recommend_device, recommend_n_threads=recommend_n_threads, recommend_use_gpu_ranking=recommend_use_gpu_ranking, + train_min_user_interactions=train_min_user_interactions, trainer=trainer, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, ) - self.data_preparator = data_preparator_type( - session_max_len=session_max_len, - n_negatives=n_negatives if loss != "softmax" else None, - batch_size=batch_size, - dataloader_num_workers=dataloader_num_workers, + + def _init_data_preparator(self) -> None: + self.data_preparator = self.data_preparator_type( + session_max_len=self.session_max_len, + n_negatives=self.n_negatives if self.loss != "softmax" else None, + batch_size=self.batch_size, + dataloader_num_workers=self.dataloader_num_workers, item_extra_tokens=(PADDING_VALUE,), - train_min_user_interactions=train_min_user_interactions, + train_min_user_interactions=self.train_min_user_interactions, + ) + + def _get_config(self) -> TransformerModelConfig: + return SASRecModelConfig( + cls=self.__class__, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + n_factors=self.n_factors, + use_pos_emb=self.use_pos_emb, + use_causal_attn=self.use_causal_attn, + use_key_padding_mask=self.use_key_padding_mask, + dropout_rate=self.dropout_rate, + session_max_len=self.session_max_len, + dataloader_num_workers=self.dataloader_num_workers, + batch_size=self.batch_size, + loss=self.loss, + n_negatives=self.n_negatives, + gbce_t=self.gbce_t, + lr=self.lr, + epochs=self.epochs, + verbose=self.verbose, + deterministic=self.deterministic, + recommend_device=self.recommend_device, + recommend_n_threads=self.recommend_n_threads, + recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, + train_min_user_interactions=self.train_min_user_interactions, + item_net_block_types=self.item_net_block_types, + pos_encoding_type=self.pos_encoding_type, + transformer_layers_type=self.transformer_layers_type, + data_preparator_type=self.data_preparator_type, + lightning_module_type=self.lightning_module_type, + ) + + @classmethod + def _from_config(cls, config: TransformerModelConfig) -> tpe.Self: + return cls( + trainer=None, + n_blocks=config.n_blocks, + n_heads=config.n_heads, + n_factors=config.n_factors, + use_pos_emb=config.use_pos_emb, + use_causal_attn=config.use_causal_attn, + use_key_padding_mask=config.use_key_padding_mask, + dropout_rate=config.dropout_rate, + session_max_len=config.session_max_len, + dataloader_num_workers=config.dataloader_num_workers, + batch_size=config.batch_size, + loss=config.loss, + n_negatives=config.n_negatives, + gbce_t=config.gbce_t, + lr=config.lr, + epochs=config.epochs, + verbose=config.verbose, + deterministic=config.deterministic, + recommend_device=config.recommend_device, + recommend_n_threads=config.recommend_n_threads, + recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, + train_min_user_interactions=config.train_min_user_interactions, + item_net_block_types=config.item_net_block_types, + pos_encoding_type=config.pos_encoding_type, + transformer_layers_type=config.transformer_layers_type, + data_preparator_type=config.data_preparator_type, + lightning_module_type=config.lightning_module_type, ) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 64baf995..cf27a5be 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -17,15 +17,17 @@ import numpy as np import torch +import typing_extensions as tpe from implicit.gpu import HAS_CUDA +from pydantic import BeforeValidator, PlainSerializer from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.accelerators import Accelerator from rectools import ExternalIds from rectools.dataset import Dataset -from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase +from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig from rectools.models.rank import Distance, ImplicitRanker from rectools.types import InternalIdsArray +from rectools.utils.misc import get_class_or_function_full_path, import_object from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase, ItemNetConstructor from .transformer_data_preparator import SessionEncoderDataPreparatorBase @@ -399,20 +401,126 @@ def _xavier_normal_init(self) -> None: torch.nn.init.xavier_normal_(param.data) +# #### -------------- Transformer Config -------------- #### # + + +def _get_class_obj(spec: tp.Any) -> tp.Any: + if not isinstance(spec, str): + return spec + return import_object(spec) + + +def _get_class_obj_sequence(spec: tp.Sequence[tp.Any]) -> tp.Any: + return tuple(map(_get_class_obj, spec)) + + +def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: + return tuple(map(get_class_or_function_full_path, obj)) + + +PositionalEncodingType = tpe.Annotated[ + tp.Type[PositionalEncodingBase], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + +TransformerLayersType = tpe.Annotated[ + tp.Type[TransformerLayersBase], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + +SessionEncoderLightningModuleType = tpe.Annotated[ + tp.Type[SessionEncoderLightningModuleBase], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + +SessionEncoderDataPreparatorType = tpe.Annotated[ + tp.Type[SessionEncoderDataPreparatorBase], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + +ItemNetBlockTypes = tpe.Annotated[ + tp.Sequence[tp.Type[ItemNetBase]], + BeforeValidator(_get_class_obj_sequence), + PlainSerializer( + func=_serialize_type_sequence, + return_type=str, + when_used="json", + ), +] + + +class TransformerModelConfig(ModelConfig): + """Transformer model base config.""" + + data_preparator_type: SessionEncoderDataPreparatorType + n_blocks: int = 1 + n_heads: int = 1 + n_factors: int = 128 + use_pos_emb: bool = True + use_causal_attn: bool = True + use_key_padding_mask: bool = False + dropout_rate: float = 0.2 + session_max_len: int = 32 + dataloader_num_workers: int = 0 + batch_size: int = 128 + loss: str = "softmax" + n_negatives: int = 1 + gbce_t: float = 0.2 + lr: float = 0.01 + epochs: int = 3 + verbose: int = 0 + deterministic: bool = False + recommend_device: str = "auto" # custom accelerators not supported + recommend_n_threads: int = 0 + recommend_use_gpu_ranking: bool = True + train_min_user_interactions: int = 2 + item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet) + pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding + transformer_layers_type: TransformerLayersType = PreLNTransformerLayers + lightning_module_type: SessionEncoderLightningModuleType = SessionEncoderLightningModule + + +TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig) + + # #### -------------- Transformer Model Base -------------- #### # -class TransformerModelBase(ModelBase): # pylint: disable=too-many-instance-attributes +class TransformerModelBase(ModelBase[TransformerModelConfig_T]): # pylint: disable=too-many-instance-attributes """ Base model for all recommender algorithms that work on transformer architecture (e.g. SASRec, Bert4Rec). To create a custom transformer model it is necessary to inherit from this class and write self.data_preparator initialization logic. """ + config_class: tp.Type[TransformerModelConfig_T] + u2i_dist = Distance.DOT + i2i_dist = Distance.COSINE + def __init__( # pylint: disable=too-many-arguments, too-many-locals self, - transformer_layers_type: tp.Type[TransformerLayersBase], data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase], + transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, n_blocks: int = 1, n_heads: int = 1, n_factors: int = 128, @@ -421,58 +529,101 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals use_key_padding_mask: bool = False, dropout_rate: float = 0.2, session_max_len: int = 32, + dataloader_num_workers: int = 0, + batch_size: int = 128, loss: str = "softmax", + n_negatives: int = 1, gbce_t: float = 0.5, lr: float = 0.01, epochs: int = 3, verbose: int = 0, deterministic: bool = False, - recommend_device: tp.Union[str, Accelerator] = "auto", + recommend_device: str = "auto", recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, + train_min_user_interactions: int = 2, trainer: tp.Optional[Trainer] = None, item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, **kwargs: tp.Any, ) -> None: - super().__init__(verbose) - self.recommend_n_threads = recommend_n_threads + super().__init__(verbose=verbose) + self.transformer_layers_type = transformer_layers_type + self.data_preparator_type = data_preparator_type + self.n_blocks = n_blocks + self.n_heads = n_heads + self.n_factors = n_factors + self.use_pos_emb = use_pos_emb + self.use_causal_attn = use_causal_attn + self.use_key_padding_mask = use_key_padding_mask + self.dropout_rate = dropout_rate + self.session_max_len = session_max_len + self.dataloader_num_workers = dataloader_num_workers + self.batch_size = batch_size + self.loss = loss + self.n_negatives = n_negatives + self.gbce_t = gbce_t + self.lr = lr + self.epochs = epochs + self.deterministic = deterministic self.recommend_device = recommend_device + self.recommend_n_threads = recommend_n_threads self.recommend_use_gpu_ranking = recommend_use_gpu_ranking - self._torch_model = TransformerBasedSessionEncoder( - n_blocks=n_blocks, - n_factors=n_factors, - n_heads=n_heads, - session_max_len=session_max_len, - dropout_rate=dropout_rate, - use_pos_emb=use_pos_emb, - use_causal_attn=use_causal_attn, - use_key_padding_mask=use_key_padding_mask, - transformer_layers_type=transformer_layers_type, - item_net_block_types=item_net_block_types, - pos_encoding_type=pos_encoding_type, - ) - self.lightning_model: SessionEncoderLightningModuleBase + self.train_min_user_interactions = train_min_user_interactions + self.item_net_block_types = item_net_block_types + self.pos_encoding_type = pos_encoding_type self.lightning_module_type = lightning_module_type - self.fit_trainer: Trainer + + self._init_torch_model() + self._init_data_preparator() if trainer is None: - self._trainer = Trainer( - max_epochs=epochs, - min_epochs=epochs, - deterministic=deterministic, - enable_progress_bar=verbose > 0, - enable_model_summary=verbose > 0, - logger=verbose > 0, - ) + self._init_trainer() else: self._trainer = trainer + + self.lightning_model: SessionEncoderLightningModuleBase self.data_preparator: SessionEncoderDataPreparatorBase - self.u2i_dist = Distance.DOT - self.i2i_dist = Distance.COSINE - self.lr = lr - self.loss = loss - self.gbce_t = gbce_t + self.fit_trainer: Trainer + + def _init_data_preparator(self) -> None: + raise NotImplementedError() + + def _init_trainer(self) -> None: + self._trainer = Trainer( + max_epochs=self.epochs, + min_epochs=self.epochs, + deterministic=self.deterministic, + enable_progress_bar=self.verbose > 0, + enable_model_summary=self.verbose > 0, + logger=self.verbose > 0, + enable_checkpointing=False, + devices=1, + ) + + def _init_torch_model(self) -> None: + self._torch_model = TransformerBasedSessionEncoder( + n_blocks=self.n_blocks, + n_factors=self.n_factors, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout_rate=self.dropout_rate, + use_pos_emb=self.use_pos_emb, + use_causal_attn=self.use_causal_attn, + use_key_padding_mask=self.use_key_padding_mask, + transformer_layers_type=self.transformer_layers_type, + item_net_block_types=self.item_net_block_types, + pos_encoding_type=self.pos_encoding_type, + ) + + def _init_lightning_model(self, torch_model: TransformerBasedSessionEncoder, n_item_extra_tokens: int) -> None: + self.lightning_model = self.lightning_module_type( + torch_model=torch_model, + lr=self.lr, + loss=self.loss, + gbce_t=self.gbce_t, + n_item_extra_tokens=n_item_extra_tokens, + ) def _fit( self, @@ -484,14 +635,7 @@ def _fit( torch_model = deepcopy(self._torch_model) torch_model.construct_item_net(processed_dataset) - n_item_extra_tokens = self.data_preparator.n_item_extra_tokens - self.lightning_model = self.lightning_module_type( - torch_model=torch_model, - lr=self.lr, - loss=self.loss, - gbce_t=self.gbce_t, - n_item_extra_tokens=n_item_extra_tokens, - ) + self._init_lightning_model(torch_model, self.data_preparator.n_item_extra_tokens) self.fit_trainer = deepcopy(self._trainer) self.fit_trainer.fit(self.lightning_model, train_dataloader) @@ -562,8 +706,8 @@ def _recommend_i2i( sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() item_embs = self.lightning_model.item_embs.detach().cpu().numpy() - # TODO: i2i reco do not need filtering viewed. And user most of the times has GPU - # Should we use torch dot and topk? Should be faster + # TODO: i2i reco do not need filtering viewed and user most of the times has GPU + # We should test if torch `topk`` is be faster ranker = ImplicitRanker( self.i2i_dist, diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 4b86ca2a..375f1ce4 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -24,8 +24,13 @@ from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import SASRecModel from rectools.models.nn.item_net import IdEmbeddingsItemNet -from rectools.models.nn.sasrec import PADDING_VALUE, SASRecDataPreparator -from tests.models.utils import assert_second_fit_refits_model +from rectools.models.nn.sasrec import PADDING_VALUE, SASRecDataPreparator, SASRecTransformerLayers +from rectools.models.nn.transformer_base import LearnableInversePositionalEncoding, SessionEncoderLightningModule +from tests.models.data import DATASET +from tests.models.utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_second_fit_refits_model, +) from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal # TODO: add tests with BCE and GBCE @@ -651,3 +656,200 @@ def test_get_dataloader_recommend( actual = next(iter(dataloader)) for key, value in actual.items(): assert torch.equal(value, recommend_batch[key]) + + +class TestSASRecModelConfiguration: + def setup_method(self) -> None: + self._seed_everything() + + def _seed_everything(self) -> None: + torch.use_deterministic_algorithms(True) + seed_everything(32, workers=True) + + def test_from_config(self) -> None: + config = { + "n_blocks": 2, + "n_heads": 4, + "n_factors": 64, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 10, + "dataloader_num_workers": 5, + "batch_size": 1024, + "loss": "BCE", + "n_negatives": 10, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 10, + "verbose": 1, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 5, + "item_net_block_types": (IdEmbeddingsItemNet,), + "pos_encoding_type": LearnableInversePositionalEncoding, + "transformer_layers_type": SASRecTransformerLayers, + "data_preparator_type": SASRecDataPreparator, + "lightning_module_type": SessionEncoderLightningModule, + } + model = SASRecModel.from_config(config) + assert model.n_blocks == 2 + assert model.n_heads == 4 + assert model.n_factors == 64 + assert model.use_pos_emb is False + assert model.use_causal_attn is False + assert model.use_key_padding_mask is True + assert model.dropout_rate == 0.5 + assert model.session_max_len == 10 + assert model.dataloader_num_workers == 5 + assert model.batch_size == 1024 + assert model.loss == "BCE" + assert model.n_negatives == 10 + assert model.gbce_t == 0.5 + assert model.lr == 0.001 + assert model.epochs == 10 + assert model.verbose == 1 + assert model.deterministic is True + assert model.recommend_device == "auto" + assert model.recommend_n_threads == 0 + assert model.recommend_use_gpu_ranking is True + assert model.train_min_user_interactions == 5 + assert model._trainer is not None + assert model.item_net_block_types == (IdEmbeddingsItemNet,) + assert model.pos_encoding_type == LearnableInversePositionalEncoding + assert model.transformer_layers_type == SASRecTransformerLayers + assert model.data_preparator_type == SASRecDataPreparator + assert model.lightning_module_type == SessionEncoderLightningModule + + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config(self, simple_types: bool) -> None: + model = SASRecModel( + n_blocks=2, + n_heads=4, + n_factors=64, + use_pos_emb=False, + use_causal_attn=False, + use_key_padding_mask=True, + dropout_rate=0.5, + session_max_len=10, + dataloader_num_workers=5, + batch_size=1024, + loss="BCE", + n_negatives=10, + gbce_t=0.5, + lr=0.001, + epochs=10, + verbose=1, + deterministic=True, + recommend_device="auto", + recommend_n_threads=0, + recommend_use_gpu_ranking=True, + train_min_user_interactions=5, + item_net_block_types=(IdEmbeddingsItemNet,), + pos_encoding_type=LearnableInversePositionalEncoding, + transformer_layers_type=SASRecTransformerLayers, + data_preparator_type=SASRecDataPreparator, + lightning_module_type=SessionEncoderLightningModule, + ) + config = model.get_config(simple_types=simple_types) + expected = { + "cls": "SASRecModel" if simple_types else SASRecModel, + "n_blocks": 2, + "n_heads": 4, + "n_factors": 64, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 10, + "dataloader_num_workers": 5, + "batch_size": 1024, + "loss": "BCE", + "n_negatives": 10, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 10, + "verbose": 1, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 5, + "item_net_block_types": ( + ["rectools.models.nn.item_net.IdEmbeddingsItemNet"] if simple_types else (IdEmbeddingsItemNet,) + ), + "pos_encoding_type": ( + "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding" + if simple_types + else LearnableInversePositionalEncoding + ), + "transformer_layers_type": ( + "rectools.models.nn.sasrec.SASRecTransformerLayers" if simple_types else SASRecTransformerLayers + ), + "data_preparator_type": ( + "rectools.models.nn.sasrec.SASRecDataPreparator" if simple_types else SASRecDataPreparator + ), + "lightning_module_type": ( + "rectools.models.nn.transformer_base.SessionEncoderLightningModule" + if simple_types + else SessionEncoderLightningModule + ), + } + assert config == expected + + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: + initial_config = { + "n_blocks": 1, + "n_heads": 1, + "n_factors": 10, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 5, + "dataloader_num_workers": 1, + "batch_size": 100, + "loss": "BCE", + "n_negatives": 4, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 1, + "verbose": 0, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 2, + "item_net_block_types": (IdEmbeddingsItemNet,), + "pos_encoding_type": LearnableInversePositionalEncoding, + "transformer_layers_type": SASRecTransformerLayers, + "data_preparator_type": SASRecDataPreparator, + "lightning_module_type": SessionEncoderLightningModule, + } + + dataset = DATASET + model = SASRecModel + + def get_reco(model: SASRecModel) -> pd.DataFrame: + return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) + + model_1 = model.from_config(initial_config) + reco_1 = get_reco(model_1) + config_1 = model_1.get_config(simple_types=simple_types) + + self._seed_everything() + model_2 = model.from_config(config_1) + reco_2 = get_reco(model_2) + config_2 = model_2.get_config(simple_types=simple_types) + + assert config_1 == config_2 + pd.testing.assert_frame_equal(reco_1, reco_2) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = SASRecModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) From 3fd3217c527b6db6b88e5f3b648278ce91517b8d Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 22 Jan 2025 14:34:57 +0300 Subject: [PATCH 02/25] tests work linters not --- rectools/models/nn/bert4rec.py | 126 +++++++++- rectools/models/nn/sasrec.py | 7 +- rectools/models/nn/transformer_base.py | 12 +- .../models/nn/transformer_data_preparator.py | 1 + tests/models/nn/test_bertrec.py | 238 ++++++++++++++++++ 5 files changed, 364 insertions(+), 20 deletions(-) create mode 100644 tests/models/nn/test_bertrec.py diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 45c65bd0..a62eea5a 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -17,15 +17,21 @@ import numpy as np import torch +import typing_extensions as tpe +from pydantic import BeforeValidator, PlainSerializer from pytorch_lightning import Trainer from pytorch_lightning.accelerators import Accelerator +from rectools.utils.misc import get_class_or_function_full_path + from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase -from .transformer_base import ( +from .transformer_base import ( # SessionEncoderDataPreparatorType_T, PADDING_VALUE, SessionEncoderLightningModule, SessionEncoderLightningModuleBase, TransformerModelBase, + TransformerModelConfig, + _get_class_obj, ) from .transformer_data_preparator import SessionEncoderDataPreparatorBase from .transformer_net_blocks import ( @@ -113,7 +119,26 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D return {"x": torch.LongTensor(x)} -class BERT4RecModel(TransformerModelBase): +BERT4RecDataPreparatorType = tpe.Annotated[ + tp.Type[BERT4RecDataPreparator], + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + + +class BERT4RecModelConfig(TransformerModelConfig): + """BERT4RecModel config.""" + + data_preparator_type: BERT4RecDataPreparatorType = BERT4RecDataPreparator + use_key_padding_mask: bool = True + mask_prob: float = 0.15 + + +class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): """ BERT4Rec model. @@ -189,6 +214,8 @@ class BERT4RecModel(TransformerModelBase): Type of lightning module defining training procedure. """ + config_class = BERT4RecModelConfig + def __init__( # pylint: disable=too-many-arguments, too-many-locals self, n_blocks: int = 1, @@ -220,6 +247,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals data_preparator_type: tp.Type[BERT4RecDataPreparator] = BERT4RecDataPreparator, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, ): + self.mask_prob = mask_prob + super().__init__( transformer_layers_type=transformer_layers_type, data_preparator_type=data_preparator_type, @@ -230,27 +259,98 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals use_causal_attn=use_causal_attn, use_key_padding_mask=use_key_padding_mask, dropout_rate=dropout_rate, + session_max_len=session_max_len, + dataloader_num_workers=dataloader_num_workers, + batch_size=batch_size, + loss=loss, + n_negatives=n_negatives, + gbce_t=gbce_t, + lr=lr, epochs=epochs, verbose=verbose, deterministic=deterministic, recommend_device=recommend_device, recommend_n_threads=recommend_n_threads, recommend_use_gpu_ranking=recommend_use_gpu_ranking, - loss=loss, - gbce_t=gbce_t, - lr=lr, - session_max_len=session_max_len + 1, + train_min_user_interactions=train_min_user_interactions, trainer=trainer, item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, ) - self.data_preparator = data_preparator_type( - session_max_len=session_max_len, - n_negatives=n_negatives if loss != "softmax" else None, - batch_size=batch_size, - dataloader_num_workers=dataloader_num_workers, - train_min_user_interactions=train_min_user_interactions, + + def _init_data_preparator(self) -> None: # TODO: negative losses are not working now + self.data_preparator: SessionEncoderDataPreparatorBase = self.data_preparator_type( + session_max_len=self.session_max_len, # -1 + n_negatives=self.n_negatives if self.loss != "softmax" else None, + batch_size=self.batch_size, + dataloader_num_workers=self.dataloader_num_workers, + train_min_user_interactions=self.train_min_user_interactions, item_extra_tokens=(PADDING_VALUE, MASKING_VALUE), - mask_prob=mask_prob, + mask_prob=self.mask_prob, + ) + + def _get_config(self) -> BERT4RecModelConfig: + return BERT4RecModelConfig( + cls=self.__class__, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + n_factors=self.n_factors, + use_pos_emb=self.use_pos_emb, + use_causal_attn=self.use_causal_attn, + use_key_padding_mask=self.use_key_padding_mask, + dropout_rate=self.dropout_rate, + session_max_len=self.session_max_len, + dataloader_num_workers=self.dataloader_num_workers, + batch_size=self.batch_size, + loss=self.loss, + n_negatives=self.n_negatives, + gbce_t=self.gbce_t, + lr=self.lr, + epochs=self.epochs, + verbose=self.verbose, + deterministic=self.deterministic, + recommend_device=self.recommend_device, + recommend_n_threads=self.recommend_n_threads, + recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, + train_min_user_interactions=self.train_min_user_interactions, + item_net_block_types=self.item_net_block_types, + pos_encoding_type=self.pos_encoding_type, + transformer_layers_type=self.transformer_layers_type, + data_preparator_type=self.data_preparator_type, + lightning_module_type=self.lightning_module_type, + mask_prob=self.mask_prob, + ) + + @classmethod + def _from_config(cls, config: BERT4RecModelConfig) -> tpe.Self: + return cls( + trainer=None, + n_blocks=config.n_blocks, + n_heads=config.n_heads, + n_factors=config.n_factors, + use_pos_emb=config.use_pos_emb, + use_causal_attn=config.use_causal_attn, + use_key_padding_mask=config.use_key_padding_mask, + dropout_rate=config.dropout_rate, + session_max_len=config.session_max_len, + dataloader_num_workers=config.dataloader_num_workers, + batch_size=config.batch_size, + loss=config.loss, + n_negatives=config.n_negatives, + gbce_t=config.gbce_t, + lr=config.lr, + epochs=config.epochs, + verbose=config.verbose, + deterministic=config.deterministic, + recommend_device=config.recommend_device, + recommend_n_threads=config.recommend_n_threads, + recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, + train_min_user_interactions=config.train_min_user_interactions, + item_net_block_types=config.item_net_block_types, + pos_encoding_type=config.pos_encoding_type, + transformer_layers_type=config.transformer_layers_type, + data_preparator_type=config.data_preparator_type, + lightning_module_type=config.lightning_module_type, + mask_prob=config.mask_prob, ) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 5b7a5aad..7744d446 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -14,10 +14,10 @@ import typing as tp from typing import Dict, List, Tuple -import typing_extensions as tpe import numpy as np import torch +import typing_extensions as tpe from pytorch_lightning import Trainer from torch import nn @@ -165,6 +165,7 @@ class SASRecModelConfig(TransformerModelConfig): data_preparator_type: SessionEncoderDataPreparatorType = SASRecDataPreparator transformer_layers_type: TransformerLayersType = SASRecTransformerLayers + use_causal_attn: bool = True class SASRecModel(TransformerModelBase[SASRecModelConfig]): @@ -313,7 +314,7 @@ def _init_data_preparator(self) -> None: train_min_user_interactions=self.train_min_user_interactions, ) - def _get_config(self) -> TransformerModelConfig: + def _get_config(self) -> SASRecModelConfig: return SASRecModelConfig( cls=self.__class__, n_blocks=self.n_blocks, @@ -345,7 +346,7 @@ def _get_config(self) -> TransformerModelConfig: ) @classmethod - def _from_config(cls, config: TransformerModelConfig) -> tpe.Self: + def _from_config(cls, config: SASRecModelConfig) -> tpe.Self: return cls( trainer=None, n_blocks=config.n_blocks, diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index cf27a5be..7e4d9652 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -401,7 +401,7 @@ def _xavier_normal_init(self) -> None: torch.nn.init.xavier_normal_(param.data) -# #### -------------- Transformer Config -------------- #### # +# #### -------------- Transformer Model Config -------------- #### # def _get_class_obj(spec: tp.Any) -> tp.Any: @@ -468,16 +468,20 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: ), ] +SessionEncoderDataPreparatorType_T = tp.TypeVar( + "SessionEncoderDataPreparatorType_T", bound=SessionEncoderDataPreparatorType +) + class TransformerModelConfig(ModelConfig): """Transformer model base config.""" - data_preparator_type: SessionEncoderDataPreparatorType + data_preparator_type: SessionEncoderDataPreparatorType_T n_blocks: int = 1 n_heads: int = 1 n_factors: int = 128 use_pos_emb: bool = True - use_causal_attn: bool = True + use_causal_attn: bool = False use_key_padding_mask: bool = False dropout_rate: float = 0.2 session_max_len: int = 32 @@ -525,7 +529,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals n_heads: int = 1, n_factors: int = 128, use_pos_emb: bool = True, - use_causal_attn: bool = True, + use_causal_attn: bool = False, use_key_padding_mask: bool = False, dropout_rate: float = 0.2, session_max_len: int = 32, diff --git a/rectools/models/nn/transformer_data_preparator.py b/rectools/models/nn/transformer_data_preparator.py index fd45434c..243df676 100644 --- a/rectools/models/nn/transformer_data_preparator.py +++ b/rectools/models/nn/transformer_data_preparator.py @@ -110,6 +110,7 @@ def __init__( shuffle_train: bool = True, train_min_user_interactions: int = 2, n_negatives: tp.Optional[int] = None, + **kwargs: tp.Any, ) -> None: """TODO""" self.item_id_map: IdMap diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py new file mode 100644 index 00000000..e0a8d8e1 --- /dev/null +++ b/tests/models/nn/test_bertrec.py @@ -0,0 +1,238 @@ +# Copyright 2024 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import numpy as np +import pandas as pd +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.models import BERT4RecModel +from rectools.models.nn.bert4rec import BERT4RecDataPreparator +from rectools.models.nn.item_net import IdEmbeddingsItemNet +from rectools.models.nn.transformer_base import ( + LearnableInversePositionalEncoding, + PreLNTransformerLayers, + SessionEncoderLightningModule, +) +from tests.models.data import DATASET +from tests.models.utils import assert_default_config_and_default_model_params_are_the_same + +# TODO: add tests with BCE and GBCE (they can be broken for the model when softmax is ok => we need happy path test) + + +class TestBERT4RecModelConfiguration: + def setup_method(self) -> None: + self._seed_everything() + + def _seed_everything(self) -> None: + torch.use_deterministic_algorithms(True) + seed_everything(32, workers=True) + + def test_from_config(self) -> None: + config = { + "n_blocks": 2, + "n_heads": 4, + "n_factors": 64, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 10, + "dataloader_num_workers": 5, + "batch_size": 1024, + "loss": "softmax", + "n_negatives": 10, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 10, + "verbose": 1, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 5, + "item_net_block_types": (IdEmbeddingsItemNet,), + "pos_encoding_type": LearnableInversePositionalEncoding, + "transformer_layers_type": PreLNTransformerLayers, + "data_preparator_type": BERT4RecDataPreparator, + "lightning_module_type": SessionEncoderLightningModule, + "mask_prob": 0.15, + } + model = BERT4RecModel.from_config(config) + assert model.n_blocks == 2 + assert model.n_heads == 4 + assert model.n_factors == 64 + assert model.use_pos_emb is False + assert model.use_causal_attn is False + assert model.use_key_padding_mask is True + assert model.dropout_rate == 0.5 + assert model.session_max_len == 10 + assert model.dataloader_num_workers == 5 + assert model.batch_size == 1024 + assert model.loss == "softmax" + assert model.n_negatives == 10 + assert model.gbce_t == 0.5 + assert model.lr == 0.001 + assert model.epochs == 10 + assert model.verbose == 1 + assert model.deterministic is True + assert model.recommend_device == "auto" + assert model.recommend_n_threads == 0 + assert model.recommend_use_gpu_ranking is True + assert model.train_min_user_interactions == 5 + assert model._trainer is not None + assert model.item_net_block_types == (IdEmbeddingsItemNet,) + assert model.pos_encoding_type == LearnableInversePositionalEncoding + assert model.transformer_layers_type == PreLNTransformerLayers + assert model.data_preparator_type == BERT4RecDataPreparator + assert model.lightning_module_type == SessionEncoderLightningModule + assert model.mask_prob == 0.15 + + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config(self, simple_types: bool) -> None: + model = BERT4RecModel( + n_blocks=2, + n_heads=4, + n_factors=64, + use_pos_emb=False, + use_causal_attn=False, + use_key_padding_mask=True, + dropout_rate=0.5, + session_max_len=10, + dataloader_num_workers=5, + batch_size=1024, + loss="softmax", + n_negatives=10, + gbce_t=0.5, + lr=0.001, + epochs=10, + verbose=1, + deterministic=True, + recommend_device="auto", + recommend_n_threads=0, + recommend_use_gpu_ranking=True, + train_min_user_interactions=5, + item_net_block_types=(IdEmbeddingsItemNet,), + pos_encoding_type=LearnableInversePositionalEncoding, + transformer_layers_type=PreLNTransformerLayers, + data_preparator_type=BERT4RecDataPreparator, + lightning_module_type=SessionEncoderLightningModule, + mask_prob=0.15, + ) + config = model.get_config(simple_types=simple_types) + expected = { + "cls": "BERT4RecModel" if simple_types else BERT4RecModel, + "n_blocks": 2, + "n_heads": 4, + "n_factors": 64, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 10, + "dataloader_num_workers": 5, + "batch_size": 1024, + "loss": "softmax", + "n_negatives": 10, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 10, + "verbose": 1, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 5, + "item_net_block_types": ( + ["rectools.models.nn.item_net.IdEmbeddingsItemNet"] if simple_types else (IdEmbeddingsItemNet,) + ), + "pos_encoding_type": ( + "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding" + if simple_types + else LearnableInversePositionalEncoding + ), + "transformer_layers_type": ( + "rectools.models.nn.transformer_net_blocks.PreLNTransformerLayers" + if simple_types + else PreLNTransformerLayers + ), + "data_preparator_type": ( + "rectools.models.nn.bert4rec.BERT4RecDataPreparator" if simple_types else BERT4RecDataPreparator + ), + "lightning_module_type": ( + "rectools.models.nn.transformer_base.SessionEncoderLightningModule" + if simple_types + else SessionEncoderLightningModule + ), + "mask_prob": 0.15, + } + assert config == expected + + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: + initial_config = { + "n_blocks": 1, + "n_heads": 1, + "n_factors": 10, + "use_pos_emb": False, + "use_causal_attn": False, + "use_key_padding_mask": True, + "dropout_rate": 0.5, + "session_max_len": 5, + "dataloader_num_workers": 1, + "batch_size": 100, + "loss": "softmax", + "n_negatives": 4, + "gbce_t": 0.5, + "lr": 0.001, + "epochs": 1, + "verbose": 0, + "deterministic": True, + "recommend_device": "auto", + "recommend_n_threads": 0, + "recommend_use_gpu_ranking": True, + "train_min_user_interactions": 2, + "item_net_block_types": (IdEmbeddingsItemNet,), + "pos_encoding_type": LearnableInversePositionalEncoding, + "transformer_layers_type": PreLNTransformerLayers, + "data_preparator_type": BERT4RecDataPreparator, + "lightning_module_type": SessionEncoderLightningModule, + "mask_prob": 0.15, + } + + dataset = DATASET + model = BERT4RecModel + + def get_reco(model: BERT4RecModel) -> pd.DataFrame: + return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) + + model_1 = model.from_config(initial_config) + reco_1 = get_reco(model_1) + config_1 = model_1.get_config(simple_types=simple_types) + + self._seed_everything() + model_2 = model.from_config(config_1) + reco_2 = get_reco(model_2) + config_2 = model_2.get_config(simple_types=simple_types) + + assert config_1 == config_2 + pd.testing.assert_frame_equal(reco_1, reco_2) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = BERT4RecModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) From 976c14a72b9c71122ed406a9278c259c931e6969 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 22 Jan 2025 14:57:53 +0300 Subject: [PATCH 03/25] all work --- rectools/models/nn/bert4rec.py | 22 ++++------------------ rectools/models/nn/transformer_base.py | 5 +++-- tests/models/nn/test_bertrec.py | 2 +- tests/models/nn/test_sasrec.py | 2 +- 4 files changed, 9 insertions(+), 22 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index a62eea5a..175b8e7c 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -18,20 +18,17 @@ import numpy as np import torch import typing_extensions as tpe -from pydantic import BeforeValidator, PlainSerializer from pytorch_lightning import Trainer from pytorch_lightning.accelerators import Accelerator -from rectools.utils.misc import get_class_or_function_full_path - from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase -from .transformer_base import ( # SessionEncoderDataPreparatorType_T, +from .transformer_base import ( PADDING_VALUE, + SessionEncoderDataPreparatorType, SessionEncoderLightningModule, SessionEncoderLightningModuleBase, TransformerModelBase, TransformerModelConfig, - _get_class_obj, ) from .transformer_data_preparator import SessionEncoderDataPreparatorBase from .transformer_net_blocks import ( @@ -119,21 +116,10 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D return {"x": torch.LongTensor(x)} -BERT4RecDataPreparatorType = tpe.Annotated[ - tp.Type[BERT4RecDataPreparator], - BeforeValidator(_get_class_obj), - PlainSerializer( - func=get_class_or_function_full_path, - return_type=str, - when_used="json", - ), -] - - class BERT4RecModelConfig(TransformerModelConfig): """BERT4RecModel config.""" - data_preparator_type: BERT4RecDataPreparatorType = BERT4RecDataPreparator + data_preparator_type: SessionEncoderDataPreparatorType = BERT4RecDataPreparator use_key_padding_mask: bool = True mask_prob: float = 0.15 @@ -244,7 +230,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, - data_preparator_type: tp.Type[BERT4RecDataPreparator] = BERT4RecDataPreparator, + data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = BERT4RecDataPreparator, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, ): self.mask_prob = mask_prob diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 7e4d9652..5ccc0836 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -476,7 +476,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: class TransformerModelConfig(ModelConfig): """Transformer model base config.""" - data_preparator_type: SessionEncoderDataPreparatorType_T + data_preparator_type: SessionEncoderDataPreparatorType n_blocks: int = 1 n_heads: int = 1 n_factors: int = 128 @@ -523,7 +523,7 @@ class TransformerModelBase(ModelBase[TransformerModelConfig_T]): # pylint: disa def __init__( # pylint: disable=too-many-arguments, too-many-locals self, - data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase], + data_preparator_type: SessionEncoderDataPreparatorType, # tp.Type[SessionEncoderDataPreparatorBase], transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, n_blocks: int = 1, n_heads: int = 1, @@ -581,6 +581,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self._init_torch_model() self._init_data_preparator() + if trainer is None: self._init_trainer() else: diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py index e0a8d8e1..0d8a7132 100644 --- a/tests/models/nn/test_bertrec.py +++ b/tests/models/nn/test_bertrec.py @@ -94,7 +94,7 @@ def test_from_config(self) -> None: assert model.recommend_n_threads == 0 assert model.recommend_use_gpu_ranking is True assert model.train_min_user_interactions == 5 - assert model._trainer is not None + assert model._trainer is not None # pylint: disable = protected-access assert model.item_net_block_types == (IdEmbeddingsItemNet,) assert model.pos_encoding_type == LearnableInversePositionalEncoding assert model.transformer_layers_type == PreLNTransformerLayers diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 375f1ce4..95c921cd 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -717,7 +717,7 @@ def test_from_config(self) -> None: assert model.recommend_n_threads == 0 assert model.recommend_use_gpu_ranking is True assert model.train_min_user_interactions == 5 - assert model._trainer is not None + assert model._trainer is not None # pylint: disable = protected-access assert model.item_net_block_types == (IdEmbeddingsItemNet,) assert model.pos_encoding_type == LearnableInversePositionalEncoding assert model.transformer_layers_type == SASRecTransformerLayers From 903b27fd0db6a61a7cdbd115c157600b2b98eda4 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Wed, 22 Jan 2025 15:43:10 +0300 Subject: [PATCH 04/25] default params --- rectools/models/nn/bert4rec.py | 10 +++++----- rectools/models/nn/sasrec.py | 10 +++++----- rectools/models/nn/transformer_base.py | 20 ++++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 175b8e7c..7f0cfef6 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -204,9 +204,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): def __init__( # pylint: disable=too-many-arguments, too-many-locals self, - n_blocks: int = 1, - n_heads: int = 1, - n_factors: int = 128, + n_blocks: int = 2, + n_heads: int = 4, + n_factors: int = 256, use_pos_emb: bool = True, use_causal_attn: bool = False, use_key_padding_mask: bool = True, @@ -217,12 +217,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals recommend_device: Union[str, Accelerator] = "auto", recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, - session_max_len: int = 32, + session_max_len: int = 100, n_negatives: int = 1, batch_size: int = 128, loss: str = "softmax", gbce_t: float = 0.2, - lr: float = 0.01, + lr: float = 0.001, dataloader_num_workers: int = 0, train_min_user_interactions: int = 2, mask_prob: float = 0.15, diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 7744d446..dcb2e89c 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -246,20 +246,20 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): def __init__( # pylint: disable=too-many-arguments, too-many-locals self, - n_blocks: int = 1, - n_heads: int = 1, - n_factors: int = 128, + n_blocks: int = 2, + n_heads: int = 4, + n_factors: int = 256, use_pos_emb: bool = True, use_causal_attn: bool = True, use_key_padding_mask: bool = False, dropout_rate: float = 0.2, - session_max_len: int = 32, + session_max_len: int = 100, dataloader_num_workers: int = 0, batch_size: int = 128, loss: str = "softmax", n_negatives: int = 1, gbce_t: float = 0.2, - lr: float = 0.01, + lr: float = 0.001, epochs: int = 3, verbose: int = 0, deterministic: bool = False, diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 5ccc0836..4421bc08 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -477,20 +477,20 @@ class TransformerModelConfig(ModelConfig): """Transformer model base config.""" data_preparator_type: SessionEncoderDataPreparatorType - n_blocks: int = 1 - n_heads: int = 1 - n_factors: int = 128 + n_blocks: int = 2 + n_heads: int = 4 + n_factors: int = 256 use_pos_emb: bool = True use_causal_attn: bool = False use_key_padding_mask: bool = False dropout_rate: float = 0.2 - session_max_len: int = 32 + session_max_len: int = 100 dataloader_num_workers: int = 0 batch_size: int = 128 loss: str = "softmax" n_negatives: int = 1 gbce_t: float = 0.2 - lr: float = 0.01 + lr: float = 0.001 epochs: int = 3 verbose: int = 0 deterministic: bool = False @@ -525,20 +525,20 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self, data_preparator_type: SessionEncoderDataPreparatorType, # tp.Type[SessionEncoderDataPreparatorBase], transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, - n_blocks: int = 1, - n_heads: int = 1, - n_factors: int = 128, + n_blocks: int = 2, + n_heads: int = 4, + n_factors: int = 256, use_pos_emb: bool = True, use_causal_attn: bool = False, use_key_padding_mask: bool = False, dropout_rate: float = 0.2, - session_max_len: int = 32, + session_max_len: int = 100, dataloader_num_workers: int = 0, batch_size: int = 128, loss: str = "softmax", n_negatives: int = 1, gbce_t: float = 0.5, - lr: float = 0.01, + lr: float = 0.001, epochs: int = 3, verbose: int = 0, deterministic: bool = False, From 7b18f5f0f4be20175253ea817adf89450804e822 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Thu, 23 Jan 2025 14:45:59 +0300 Subject: [PATCH 05/25] removed caching item embeddings --- rectools/models/nn/transformer_base.py | 40 +++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 4421bc08..665c8b17 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -382,10 +382,16 @@ def _calc_gbce_loss( loss = self._calc_bce_loss(logits, y, w) return loss - def on_train_end(self) -> None: + def on_predict_epoch_start(self) -> None: """Save item embeddings""" self.eval() - self.item_embs = self.torch_model.item_model.get_all_embeddings() + with torch.no_grad(): + self.item_embs = self.torch_model.item_model.get_all_embeddings() + + def on_predict_epoch_end(self) -> None: + """Clear item embeddings""" + del self.item_embs + torch.cuda.empty_cache() def predict_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """ @@ -674,13 +680,13 @@ def _recommend_u2i( raise ValueError(explanation) user_embs = np.concatenate(session_embs, axis=0) user_embs = user_embs[user_ids] - item_embs = self.lightning_model.item_embs - item_embs_np = item_embs.detach().cpu().numpy() + + item_embs = self.get_item_vectors() ranker = ImplicitRanker( self.u2i_dist, user_embs, # [n_rec_users, n_factors] - item_embs_np, # [n_items + n_item_extra_tokens, n_factors] + item_embs, # [n_items + n_item_extra_tokens, n_factors] ) if filter_viewed: user_items = dataset.get_user_item_matrix(include_weights=False) @@ -688,7 +694,7 @@ def _recommend_u2i( else: ui_csr_for_filter = None - # TODO: When filter_viewed is not needed and user has GPU, torch DOT and topk should be faster + # TODO: We should test if torch `topk`` is faster when `filter_viewed`` is ``False`` user_ids_indices, all_reco_ids, all_scores = ranker.rank( subject_ids=np.arange(user_embs.shape[0]), # n_rec_users k=k, @@ -698,7 +704,21 @@ def _recommend_u2i( use_gpu=self.recommend_use_gpu_ranking and HAS_CUDA, ) all_target_ids = user_ids[user_ids_indices] - return all_target_ids, all_reco_ids, all_scores # n_rec_users, model_internal, scores + return all_target_ids, all_reco_ids, all_scores + + def get_item_vectors(self) -> np.ndarray: + """ + Compute catalog item embeddings through torch model. + + Returns + ------- + np.ndarray + Full catalog item embeddings including extra tokens. + """ + self.torch_model.eval() + with torch.no_grad(): + item_embs = self.torch_model.item_model.get_all_embeddings().detach().cpu().numpy() + return item_embs def _recommend_i2i( self, @@ -710,9 +730,9 @@ def _recommend_i2i( if sorted_item_ids_to_recommend is None: sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() - item_embs = self.lightning_model.item_embs.detach().cpu().numpy() - # TODO: i2i reco do not need filtering viewed and user most of the times has GPU - # We should test if torch `topk`` is be faster + item_embs = self.get_item_vectors() + # TODO: i2i recommendations do not need filtering viewed and user most of the times has GPU + # We should test if torch `topk`` is faster ranker = ImplicitRanker( self.i2i_dist, From c5f04e1b9d4c343079ad74752c395abcbce2540b Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Thu, 23 Jan 2025 17:39:37 +0300 Subject: [PATCH 06/25] configs correct with new args --- rectools/models/nn/bert4rec.py | 36 ++++++++++++++----- rectools/models/nn/sasrec.py | 30 ++++++++++++---- rectools/models/nn/transformer_base.py | 30 +++++++++++++--- .../models/nn/transformer_data_preparator.py | 4 +-- tests/models/nn/test_bertrec.py | 20 ++++++++--- tests/models/nn/test_sasrec.py | 24 +++++++++---- 6 files changed, 110 insertions(+), 34 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 7f0cfef6..5ea346a9 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -13,13 +13,12 @@ # limitations under the License. import typing as tp -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import numpy as np import torch import typing_extensions as tpe from pytorch_lightning import Trainer -from pytorch_lightning.accelerators import Accelerator from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase from .transformer_base import ( @@ -170,10 +169,21 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): deterministic : bool, default ``False`` If ``True``, set deterministic algorithms for PyTorch operations. Use `pytorch_lightning.seed_everything` together with this parameter to fix the random state. - recommend_device : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"} or Accelerator, default "auto" - Device for recommend. Used at predict_step of lightning module. + recommend_batch_size : int, default 256 + How many samples per batch to load during `recommend`. + If you want to change this parameter after model is initialized, + you can manually assign new value to model `recommend_batch_size` attribute. + recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto" + Accelerator type for `recommend`. Used at predict_step of lightning module. + If you want to change this parameter after model is initialized, + you can manually assign new value to model `recommend_accelerator` attribute. + recommend_devices : int | List[int], default 1 + Devices for `recommend`. Please note that multi-device inference is not supported! + Do not specify more then one device. For ``gpu`` accelerator you can pass which device to + use, e.g. ``[1]``. + Used at predict_step of lightning module. + Multi-device recommendations are not supported. If you want to change this parameter after model is initialized, - you can manually assign new value to model `recommend_device` attribute. recommend_n_threads : int, default 0 Number of threads to use in ranker if GPU ranking is turned off or unavailable. If you want to change this parameter after model is initialized, @@ -214,7 +224,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs: int = 3, verbose: int = 0, deterministic: bool = False, - recommend_device: Union[str, Accelerator] = "auto", + recommend_batch_size: int = 256, + recommend_accelerator: str = "auto", + recommend_devices: tp.Union[int, tp.List[int]] = 1, recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, session_max_len: int = 100, @@ -255,7 +267,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs=epochs, verbose=verbose, deterministic=deterministic, - recommend_device=recommend_device, + recommend_batch_size=recommend_batch_size, + recommend_accelerator=recommend_accelerator, + recommend_devices=recommend_devices, recommend_n_threads=recommend_n_threads, recommend_use_gpu_ranking=recommend_use_gpu_ranking, train_min_user_interactions=train_min_user_interactions, @@ -296,7 +310,9 @@ def _get_config(self) -> BERT4RecModelConfig: epochs=self.epochs, verbose=self.verbose, deterministic=self.deterministic, - recommend_device=self.recommend_device, + recommend_devices=self.recommend_devices, + recommend_accelerator=self.recommend_accelerator, + recommend_batch_size=self.recommend_batch_size, recommend_n_threads=self.recommend_n_threads, recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, train_min_user_interactions=self.train_min_user_interactions, @@ -329,7 +345,9 @@ def _from_config(cls, config: BERT4RecModelConfig) -> tpe.Self: epochs=config.epochs, verbose=config.verbose, deterministic=config.deterministic, - recommend_device=config.recommend_device, + recommend_devices=config.recommend_devices, + recommend_accelerator=config.recommend_accelerator, + recommend_batch_size=config.recommend_batch_size, recommend_n_threads=config.recommend_n_threads, recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, train_min_user_interactions=config.train_min_user_interactions, diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index dcb2e89c..40e8c102 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -212,8 +212,20 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): deterministic : bool, default ``False`` If ``True``, set deterministic algorithms for PyTorch operations. Use `pytorch_lightning.seed_everything` together with this parameter to fix the random state. - recommend_device : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"} or Accelerator, default "auto" - Device for recommend. Used at predict_step of lightning module. + recommend_batch_size : int, default 256 + How many samples per batch to load during `recommend`. + If you want to change this parameter after model is initialized, + you can manually assign new value to model `recommend_batch_size` attribute. + recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto" + Accelerator type for `recommend`. Used at predict_step of lightning module. + If you want to change this parameter after model is initialized, + you can manually assign new value to model `recommend_accelerator` attribute. + recommend_devices : int | List[int], default 1 + Devices for `recommend`. Please note that multi-device inference is not supported! + Do not specify more then one device. For ``gpu`` accelerator you can pass which device to + use, e.g. ``[1]``. + Used at predict_step of lightning module. + Multi-device recommendations are not supported. If you want to change this parameter after model is initialized, you can manually assign new value to model `recommend_device` attribute. recommend_n_threads : int, default 0 @@ -263,7 +275,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs: int = 3, verbose: int = 0, deterministic: bool = False, - recommend_device: str = "auto", + recommend_batch_size: int = 256, + recommend_accelerator: str = "auto", + recommend_devices: tp.Union[int, tp.List[int]] = 1, recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, train_min_user_interactions: int = 2, @@ -294,7 +308,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs=epochs, verbose=verbose, deterministic=deterministic, - recommend_device=recommend_device, + recommend_batch_size=recommend_batch_size, + recommend_accelerator=recommend_accelerator, + recommend_devices=recommend_devices, recommend_n_threads=recommend_n_threads, recommend_use_gpu_ranking=recommend_use_gpu_ranking, train_min_user_interactions=train_min_user_interactions, @@ -334,7 +350,9 @@ def _get_config(self) -> SASRecModelConfig: epochs=self.epochs, verbose=self.verbose, deterministic=self.deterministic, - recommend_device=self.recommend_device, + recommend_devices=self.recommend_devices, + recommend_accelerator=self.recommend_accelerator, + recommend_batch_size=self.recommend_batch_size, recommend_n_threads=self.recommend_n_threads, recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, train_min_user_interactions=self.train_min_user_interactions, @@ -366,7 +384,7 @@ def _from_config(cls, config: SASRecModelConfig) -> tpe.Self: epochs=config.epochs, verbose=config.verbose, deterministic=config.deterministic, - recommend_device=config.recommend_device, + recommend_devices=config.recommend_devices, recommend_n_threads=config.recommend_n_threads, recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, train_min_user_interactions=config.train_min_user_interactions, diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 665c8b17..ea5cd187 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -500,7 +500,9 @@ class TransformerModelConfig(ModelConfig): epochs: int = 3 verbose: int = 0 deterministic: bool = False - recommend_device: str = "auto" # custom accelerators not supported + recommend_batch_size: int = 256 + recommend_accelerator: str = "auto" + recommend_devices: tp.Union[int, tp.List[int]] = 1 recommend_n_threads: int = 0 recommend_use_gpu_ranking: bool = True train_min_user_interactions: int = 2 @@ -548,7 +550,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals epochs: int = 3, verbose: int = 0, deterministic: bool = False, - recommend_device: str = "auto", + recommend_batch_size: int = 256, + recommend_accelerator: str = "auto", + recommend_devices: tp.Union[int, tp.List[int]] = 1, recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, train_min_user_interactions: int = 2, @@ -559,6 +563,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals **kwargs: tp.Any, ) -> None: super().__init__(verbose=verbose) + + self._check_devices(recommend_devices) + self.transformer_layers_type = transformer_layers_type self.data_preparator_type = data_preparator_type self.n_blocks = n_blocks @@ -577,7 +584,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.lr = lr self.epochs = epochs self.deterministic = deterministic - self.recommend_device = recommend_device + self.recommend_batch_size = recommend_batch_size + self.recommend_accelerator = recommend_accelerator + self.recommend_devices = recommend_devices self.recommend_n_threads = recommend_n_threads self.recommend_use_gpu_ranking = recommend_use_gpu_ranking self.train_min_user_interactions = train_min_user_interactions @@ -597,6 +606,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.data_preparator: SessionEncoderDataPreparatorBase self.fit_trainer: Trainer + def _check_devices(self, recommend_devices: tp.Union[int, tp.List[int]]) -> None: + if isinstance(recommend_devices, int) and recommend_devices != 1: + raise ValueError("Only single device is supported for inference") + if isinstance(recommend_devices, list) and len(recommend_devices) > 1: + raise ValueError("Only single device is supported for inference") + def _init_data_preparator(self) -> None: raise NotImplementedError() @@ -661,6 +676,10 @@ def _custom_transform_dataset_i2i( ) -> Dataset: return self.data_preparator.transform_dataset_i2i(dataset) + def _init_recommend_trainer(self) -> Trainer: + self._check_devices(self.recommend_devices) + return Trainer(devices=self.recommend_devices, accelerator=self.recommend_accelerator) + def _recommend_u2i( self, user_ids: InternalIdsArray, @@ -672,8 +691,9 @@ def _recommend_u2i( if sorted_item_ids_to_recommend is None: sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() # model internal - recommend_trainer = Trainer(devices=1, accelerator=self.recommend_device) - recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) + recommend_trainer = self._init_recommend_trainer() + recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size) + session_embs = recommend_trainer.predict(model=self.lightning_model, dataloaders=recommend_dataloader) if session_embs is None: explanation = """Received empty recommendations.""" diff --git a/rectools/models/nn/transformer_data_preparator.py b/rectools/models/nn/transformer_data_preparator.py index 243df676..bf216489 100644 --- a/rectools/models/nn/transformer_data_preparator.py +++ b/rectools/models/nn/transformer_data_preparator.py @@ -213,7 +213,7 @@ def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: ) return train_dataloader - def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: + def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoader: """TODO""" # Recommend dataloader should return interactions sorted by user ids. # User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations. @@ -222,7 +222,7 @@ def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True) recommend_dataloader = DataLoader( sequence_dataset, - batch_size=self.batch_size, + batch_size=batch_size, collate_fn=self._collate_fn_recommend, num_workers=self.dataloader_num_workers, shuffle=False, diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py index 0d8a7132..8e637254 100644 --- a/tests/models/nn/test_bertrec.py +++ b/tests/models/nn/test_bertrec.py @@ -61,7 +61,9 @@ def test_from_config(self) -> None: "epochs": 10, "verbose": 1, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 5, @@ -90,7 +92,9 @@ def test_from_config(self) -> None: assert model.epochs == 10 assert model.verbose == 1 assert model.deterministic is True - assert model.recommend_device == "auto" + assert model.recommend_accelerator == "auto" + assert model.recommend_devices == 1 + assert model.recommend_batch_size == 256 assert model.recommend_n_threads == 0 assert model.recommend_use_gpu_ranking is True assert model.train_min_user_interactions == 5 @@ -122,7 +126,9 @@ def test_get_config(self, simple_types: bool) -> None: epochs=10, verbose=1, deterministic=True, - recommend_device="auto", + recommend_accelerator="auto", + recommend_devices=1, + recommend_batch_size=256, recommend_n_threads=0, recommend_use_gpu_ranking=True, train_min_user_interactions=5, @@ -153,7 +159,9 @@ def test_get_config(self, simple_types: bool) -> None: "epochs": 10, "verbose": 1, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 5, @@ -202,7 +210,9 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "epochs": 1, "verbose": 0, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 2, diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 95c921cd..0c38cf1f 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -196,7 +196,7 @@ def test_u2i( batch_size=4, epochs=2, deterministic=True, - recommend_device=recommend_device, + recommend_accelerator=recommend_device, item_net_block_types=(IdEmbeddingsItemNet,), trainer=trainer, ) @@ -652,7 +652,7 @@ def test_get_dataloader_recommend( ) -> None: data_preparator.process_dataset_train(dataset) dataset = data_preparator.transform_dataset_i2i(dataset) - dataloader = data_preparator.get_dataloader_recommend(dataset) + dataloader = data_preparator.get_dataloader_recommend(dataset, 4) actual = next(iter(dataloader)) for key, value in actual.items(): assert torch.equal(value, recommend_batch[key]) @@ -685,7 +685,9 @@ def test_from_config(self) -> None: "epochs": 10, "verbose": 1, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 5, @@ -713,7 +715,9 @@ def test_from_config(self) -> None: assert model.epochs == 10 assert model.verbose == 1 assert model.deterministic is True - assert model.recommend_device == "auto" + assert model.recommend_accelerator == "auto" + assert model.recommend_devices == 1 + assert model.recommend_batch_size == 256 assert model.recommend_n_threads == 0 assert model.recommend_use_gpu_ranking is True assert model.train_min_user_interactions == 5 @@ -744,7 +748,9 @@ def test_get_config(self, simple_types: bool) -> None: epochs=10, verbose=1, deterministic=True, - recommend_device="auto", + recommend_accelerator="auto", + recommend_devices=1, + recommend_batch_size=256, recommend_n_threads=0, recommend_use_gpu_ranking=True, train_min_user_interactions=5, @@ -774,7 +780,9 @@ def test_get_config(self, simple_types: bool) -> None: "epochs": 10, "verbose": 1, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 5, @@ -820,7 +828,9 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "epochs": 1, "verbose": 0, "deterministic": True, - "recommend_device": "auto", + "recommend_accelerator": "auto", + "recommend_devices": 1, + "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, "train_min_user_interactions": 2, From fca4ac0f8bf006445bda1dcc1be4b35aa412f4ba Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Fri, 24 Jan 2025 16:29:49 +0300 Subject: [PATCH 07/25] calc custom loss --- rectools/models/nn/transformer_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index ea5cd187..794458ea 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -305,7 +305,9 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to negatives = batch["negatives"] logits = self._get_pos_neg_logits(x, y, negatives) return self._calc_gbce_loss(logits, y, w, negatives) + return self._calc_custom_loss(batch, batch_idx) + def _calc_custom_loss(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: raise ValueError(f"loss {self.loss} is not supported") def _get_full_catalog_logits(self, x: torch.Tensor) -> torch.Tensor: From a62e5358d3e0fdd9b01d6161060339f5ff4d6cd3 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Fri, 24 Jan 2025 17:04:18 +0300 Subject: [PATCH 08/25] fixed and added get_val_mask_func --- rectools/models/nn/bert4rec.py | 6 ++---- rectools/models/nn/sasrec.py | 5 ++--- rectools/models/nn/transformer_base.py | 18 +++++++++++++++--- .../models/nn/transformer_data_preparator.py | 1 - tests/models/nn/test_bertrec.py | 8 ++++++++ tests/models/nn/test_sasrec.py | 5 +++++ 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 72732f5f..e9917aeb 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -13,11 +13,9 @@ # limitations under the License. import typing as tp - from collections.abc import Hashable from typing import Dict, List, Tuple, Union - import numpy as np import torch import typing_extensions as tpe @@ -359,7 +357,7 @@ def _get_config(self) -> BERT4RecModelConfig: transformer_layers_type=self.transformer_layers_type, data_preparator_type=self.data_preparator_type, lightning_module_type=self.lightning_module_type, - mask_prob=self.mask_prob + mask_prob=self.mask_prob, get_val_mask_func=self.get_val_mask_func, ) @@ -396,5 +394,5 @@ def _from_config(cls, config: BERT4RecModelConfig) -> tpe.Self: data_preparator_type=config.data_preparator_type, lightning_module_type=config.lightning_module_type, mask_prob=config.mask_prob, - get_val_mask_func=self.get_val_mask_func, + get_val_mask_func=config.get_val_mask_func, ) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 8d8a4422..e22d0dea 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -347,7 +347,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_block_types=item_net_block_types, pos_encoding_type=pos_encoding_type, lightning_module_type=lightning_module_type, - get_val_mask_func=get_val_mask_func + get_val_mask_func=get_val_mask_func, ) def _init_data_preparator(self) -> None: @@ -392,7 +392,7 @@ def _get_config(self) -> SASRecModelConfig: transformer_layers_type=self.transformer_layers_type, data_preparator_type=self.data_preparator_type, lightning_module_type=self.lightning_module_type, - get_val_mask_func=self.get_val_mask_func + get_val_mask_func=self.get_val_mask_func, ) @classmethod @@ -425,6 +425,5 @@ def _from_config(cls, config: SASRecModelConfig) -> tpe.Self: transformer_layers_type=config.transformer_layers_type, data_preparator_type=config.data_preparator_type, lightning_module_type=config.lightning_module_type, - train_min_user_interactions=config.train_min_user_interactions, get_val_mask_func=config.get_val_mask_func, ) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 519ed628..9ea5f9d8 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -327,9 +327,9 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to loss = self._calc_custom_loss(batch, batch_idx) self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0) - + return loss - + def _calc_custom_loss(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: raise ValueError(f"loss {self.loss} is not supported") @@ -534,6 +534,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: "SessionEncoderDataPreparatorType_T", bound=SessionEncoderDataPreparatorType ) +CallableSerialized = tpe.Annotated[ + tp.Callable, + BeforeValidator(_get_class_obj), + PlainSerializer( + func=get_class_or_function_full_path, + return_type=str, + when_used="json", + ), +] + class TransformerModelConfig(ModelConfig): """Transformer model base config.""" @@ -566,6 +576,7 @@ class TransformerModelConfig(ModelConfig): pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding transformer_layers_type: TransformerLayersType = PreLNTransformerLayers lightning_module_type: SessionEncoderLightningModuleType = SessionEncoderLightningModule + get_val_mask_func: tp.Optional[CallableSerialized] = None TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig) @@ -618,6 +629,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet), pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding, lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule, + get_val_mask_func: tp.Optional[tp.Callable] = None, **kwargs: tp.Any, ) -> None: super().__init__(verbose=verbose) @@ -651,6 +663,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.item_net_block_types = item_net_block_types self.pos_encoding_type = pos_encoding_type self.lightning_module_type = lightning_module_type + self.get_val_mask_func = get_val_mask_func self._init_torch_model() self._init_data_preparator() @@ -723,7 +736,6 @@ def _fit( torch_model = deepcopy(self._torch_model) torch_model.construct_item_net(self.data_preparator.train_dataset) - self._init_lightning_model(torch_model, self.data_preparator.n_item_extra_tokens) self.fit_trainer = deepcopy(self._trainer) diff --git a/rectools/models/nn/transformer_data_preparator.py b/rectools/models/nn/transformer_data_preparator.py index 5a886942..f7764e1c 100644 --- a/rectools/models/nn/transformer_data_preparator.py +++ b/rectools/models/nn/transformer_data_preparator.py @@ -231,7 +231,6 @@ def get_dataloader_train(self) -> DataLoader: ) return train_dataloader - def get_dataloader_val(self) -> tp.Optional[DataLoader]: """ Construct validation dataloader from processed dataset. diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py index 8e637254..626c2773 100644 --- a/tests/models/nn/test_bertrec.py +++ b/tests/models/nn/test_bertrec.py @@ -73,6 +73,7 @@ def test_from_config(self) -> None: "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, "mask_prob": 0.15, + "get_val_mask_func": None, } model = BERT4RecModel.from_config(config) assert model.n_blocks == 2 @@ -105,6 +106,7 @@ def test_from_config(self) -> None: assert model.data_preparator_type == BERT4RecDataPreparator assert model.lightning_module_type == SessionEncoderLightningModule assert model.mask_prob == 0.15 + assert model.get_val_mask_func is None @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config(self, simple_types: bool) -> None: @@ -138,6 +140,7 @@ def test_get_config(self, simple_types: bool) -> None: data_preparator_type=BERT4RecDataPreparator, lightning_module_type=SessionEncoderLightningModule, mask_prob=0.15, + get_val_mask_func=None, ) config = model.get_config(simple_types=simple_types) expected = { @@ -187,6 +190,7 @@ def test_get_config(self, simple_types: bool) -> None: else SessionEncoderLightningModule ), "mask_prob": 0.15, + "get_val_mask_func": None, } assert config == expected @@ -222,6 +226,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, "mask_prob": 0.15, + "get_val_mask_func": None, } dataset = DATASET @@ -246,3 +251,6 @@ def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} model = BERT4RecModel() assert_default_config_and_default_model_params_are_the_same(model, default_config) + + +# TODO: test with passed custom callable as `get_val_mask_func` diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 7c090f6b..086ce9af 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -862,6 +862,7 @@ def test_from_config(self) -> None: "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, + "get_val_mask_func": None, } model = SASRecModel.from_config(config) assert model.n_blocks == 2 @@ -893,6 +894,7 @@ def test_from_config(self) -> None: assert model.transformer_layers_type == SASRecTransformerLayers assert model.data_preparator_type == SASRecDataPreparator assert model.lightning_module_type == SessionEncoderLightningModule + assert model.get_val_mask_func is None @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config(self, simple_types: bool) -> None: @@ -925,6 +927,7 @@ def test_get_config(self, simple_types: bool) -> None: transformer_layers_type=SASRecTransformerLayers, data_preparator_type=SASRecDataPreparator, lightning_module_type=SessionEncoderLightningModule, + get_val_mask_func=None, ) config = model.get_config(simple_types=simple_types) expected = { @@ -971,6 +974,7 @@ def test_get_config(self, simple_types: bool) -> None: if simple_types else SessionEncoderLightningModule ), + "get_val_mask_func": None, } assert config == expected @@ -1005,6 +1009,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, + "get_val_mask_func": None, } dataset = DATASET From 2ff5fdedddc1dbb99e956bf3d2bfa6355055077a Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 14:07:33 +0300 Subject: [PATCH 09/25] linter --- rectools/models/nn/bert4rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index e9917aeb..8b926e17 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -14,7 +14,7 @@ import typing as tp from collections.abc import Hashable -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import numpy as np import torch From 44be2d65ca9920befbd28b63bbd68e7609cb940e Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 14:26:23 +0300 Subject: [PATCH 10/25] pyling --- tests/models/nn/test_sasrec.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 086ce9af..3c6eae04 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: ignore = too-many-lines + import os import typing as tp from functools import partial From a1b15660f466a439ea735dc1038d0f010d4fd119 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 14:30:55 +0300 Subject: [PATCH 11/25] fixed --- tests/models/nn/test_sasrec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 3c6eae04..00892ce4 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: ignore = too-many-lines +# pylint: disable=too-many-lines import os import typing as tp From 67248254f04d2449aaf2c783212a87a6a33768b0 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 15:19:55 +0300 Subject: [PATCH 12/25] updated configs example --- examples/9_model_configs_and_saving.ipynb | 249 +++++++++++++++++++++- 1 file changed, 248 insertions(+), 1 deletion(-) diff --git a/examples/9_model_configs_and_saving.ipynb b/examples/9_model_configs_and_saving.ipynb index 39800a26..a5848faa 100644 --- a/examples/9_model_configs_and_saving.ipynb +++ b/examples/9_model_configs_and_saving.ipynb @@ -29,11 +29,23 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/lightfm/_lightfm_fast.py:9: UserWarning: LightFM was compiled without OpenMP support. Only a single thread will be used.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from datetime import timedelta\n", + "import pandas as pd\n", "\n", "from rectools.models import (\n", + " SASRecModel,\n", + " BERT4RecModel,\n", " ImplicitItemKNNWrapperModel, \n", " ImplicitALSWrapperModel, \n", " ImplicitBPRWrapperModel, \n", @@ -315,6 +327,241 @@ "## Configs examples for all models" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SASRec" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/plain": [ + "{'cls': 'SASRecModel',\n", + " 'verbose': 0,\n", + " 'data_preparator_type': 'rectools.models.nn.sasrec.SASRecDataPreparator',\n", + " 'n_blocks': 1,\n", + " 'n_heads': 1,\n", + " 'n_factors': 64,\n", + " 'use_pos_emb': True,\n", + " 'use_causal_attn': True,\n", + " 'use_key_padding_mask': False,\n", + " 'dropout_rate': 0.2,\n", + " 'session_max_len': 100,\n", + " 'dataloader_num_workers': 0,\n", + " 'batch_size': 128,\n", + " 'loss': 'softmax',\n", + " 'n_negatives': 1,\n", + " 'gbce_t': 0.2,\n", + " 'lr': 0.001,\n", + " 'epochs': 2,\n", + " 'deterministic': False,\n", + " 'recommend_batch_size': 256,\n", + " 'recommend_accelerator': 'auto',\n", + " 'recommend_devices': 1,\n", + " 'recommend_n_threads': 0,\n", + " 'recommend_use_gpu_ranking': True,\n", + " 'train_min_user_interactions': 2,\n", + " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", + " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", + " 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n", + " 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n", + " 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n", + " 'get_val_mask_func': None}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = SASRecModel.from_config({\n", + " \"epochs\": 2,\n", + " \"n_blocks\": 1,\n", + " \"n_heads\": 1,\n", + " \"n_factors\": 64,\n", + "})\n", + "model.get_params(simple_types=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Transformer models (SASRec and BERT4Rec) in RecTools may accept functions and classes as arguments. These types of arguments are fully compatible with RecTools configs. You can eigther pass them as python objects or as strings that define their import paths.\n", + "\n", + "Below is an example of both approaches:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", + " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", + " return self.__pydantic_serializer__.to_python(\n" + ] + }, + { + "data": { + "text/plain": [ + "{'cls': 'SASRecModel',\n", + " 'verbose': 0,\n", + " 'data_preparator_type': 'rectools.models.nn.sasrec.SASRecDataPreparator',\n", + " 'n_blocks': 2,\n", + " 'n_heads': 4,\n", + " 'n_factors': 256,\n", + " 'use_pos_emb': True,\n", + " 'use_causal_attn': True,\n", + " 'use_key_padding_mask': False,\n", + " 'dropout_rate': 0.2,\n", + " 'session_max_len': 100,\n", + " 'dataloader_num_workers': 0,\n", + " 'batch_size': 128,\n", + " 'loss': 'softmax',\n", + " 'n_negatives': 1,\n", + " 'gbce_t': 0.2,\n", + " 'lr': 0.001,\n", + " 'epochs': 3,\n", + " 'deterministic': False,\n", + " 'recommend_batch_size': 256,\n", + " 'recommend_accelerator': 'auto',\n", + " 'recommend_devices': 1,\n", + " 'recommend_n_threads': 0,\n", + " 'recommend_use_gpu_ranking': True,\n", + " 'train_min_user_interactions': 2,\n", + " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", + " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", + " 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n", + " 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n", + " 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n", + " 'get_val_mask_func': '__main__.leave_one_out_mask'}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:\n", + " rank = (\n", + " interactions\n", + " .sort_values(Columns.Datetime, ascending=False, kind=\"stable\")\n", + " .groupby(Columns.User, sort=False)\n", + " .cumcount()\n", + " )\n", + " return rank == 0\n", + "\n", + "model = SASRecModel.from_config({\n", + " \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n", + " \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n", + "})\n", + "model.get_params(simple_types=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### BERT4Rec" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", + " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", + " return self.__pydantic_serializer__.to_python(\n" + ] + }, + { + "data": { + "text/plain": [ + "{'cls': 'BERT4RecModel',\n", + " 'verbose': 0,\n", + " 'data_preparator_type': 'rectools.models.nn.bert4rec.BERT4RecDataPreparator',\n", + " 'n_blocks': 1,\n", + " 'n_heads': 1,\n", + " 'n_factors': 64,\n", + " 'use_pos_emb': True,\n", + " 'use_causal_attn': False,\n", + " 'use_key_padding_mask': True,\n", + " 'dropout_rate': 0.2,\n", + " 'session_max_len': 100,\n", + " 'dataloader_num_workers': 0,\n", + " 'batch_size': 128,\n", + " 'loss': 'softmax',\n", + " 'n_negatives': 1,\n", + " 'gbce_t': 0.2,\n", + " 'lr': 0.001,\n", + " 'epochs': 2,\n", + " 'deterministic': False,\n", + " 'recommend_batch_size': 256,\n", + " 'recommend_accelerator': 'auto',\n", + " 'recommend_devices': 1,\n", + " 'recommend_n_threads': 0,\n", + " 'recommend_use_gpu_ranking': True,\n", + " 'train_min_user_interactions': 2,\n", + " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", + " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", + " 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n", + " 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n", + " 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n", + " 'get_val_mask_func': '__main__.leave_one_out_mask',\n", + " 'mask_prob': 0.2}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = BERT4RecModel.from_config({\n", + " \"epochs\": 2,\n", + " \"n_blocks\": 1,\n", + " \"n_heads\": 1,\n", + " \"n_factors\": 64,\n", + " \"mask_prob\": 0.2,\n", + " \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n", + " \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n", + "})\n", + "model.get_params(simple_types=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, From f22a5b5066822ef20fd0c5ecad6d0200bc321b47 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 15:25:04 +0300 Subject: [PATCH 13/25] added custom validation func to config tests --- rectools/models/nn/transformer_base.py | 2 +- tests/models/nn/test_bertrec.py | 25 ++++++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 9ea5f9d8..0007c292 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -600,7 +600,7 @@ class TransformerModelBase(ModelBase[TransformerModelConfig_T]): # pylint: disa def __init__( # pylint: disable=too-many-arguments, too-many-locals self, - data_preparator_type: SessionEncoderDataPreparatorType, # tp.Type[SessionEncoderDataPreparatorBase], + data_preparator_type: SessionEncoderDataPreparatorType, transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers, n_blocks: int = 2, n_heads: int = 4, diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py index 626c2773..79abdc98 100644 --- a/tests/models/nn/test_bertrec.py +++ b/tests/models/nn/test_bertrec.py @@ -20,6 +20,7 @@ import torch from pytorch_lightning import seed_everything +from rectools import Columns from rectools.models import BERT4RecModel from rectools.models.nn.bert4rec import BERT4RecDataPreparator from rectools.models.nn.item_net import IdEmbeddingsItemNet @@ -34,6 +35,15 @@ # TODO: add tests with BCE and GBCE (they can be broken for the model when softmax is ok => we need happy path test) +def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: + rank = ( + interactions + .sort_values(Columns.Datetime, ascending=False, kind="stable") + .groupby(Columns.User, sort=False) + .cumcount() + ) + return rank == 0 + class TestBERT4RecModelConfiguration: def setup_method(self) -> None: self._seed_everything() @@ -43,6 +53,7 @@ def _seed_everything(self) -> None: seed_everything(32, workers=True) def test_from_config(self) -> None: + config = { "n_blocks": 2, "n_heads": 4, @@ -73,7 +84,7 @@ def test_from_config(self) -> None: "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, "mask_prob": 0.15, - "get_val_mask_func": None, + "get_val_mask_func": leave_one_out_mask, } model = BERT4RecModel.from_config(config) assert model.n_blocks == 2 @@ -106,7 +117,7 @@ def test_from_config(self) -> None: assert model.data_preparator_type == BERT4RecDataPreparator assert model.lightning_module_type == SessionEncoderLightningModule assert model.mask_prob == 0.15 - assert model.get_val_mask_func is None + assert model.get_val_mask_func == leave_one_out_mask # is None @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config(self, simple_types: bool) -> None: @@ -140,7 +151,7 @@ def test_get_config(self, simple_types: bool) -> None: data_preparator_type=BERT4RecDataPreparator, lightning_module_type=SessionEncoderLightningModule, mask_prob=0.15, - get_val_mask_func=None, + get_val_mask_func=leave_one_out_mask, ) config = model.get_config(simple_types=simple_types) expected = { @@ -190,7 +201,11 @@ def test_get_config(self, simple_types: bool) -> None: else SessionEncoderLightningModule ), "mask_prob": 0.15, - "get_val_mask_func": None, + "get_val_mask_func": ( + "tests.models.nn.test_bertrec.leave_one_out_mask" + if simple_types + else leave_one_out_mask + ) } assert config == expected @@ -226,7 +241,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "data_preparator_type": BERT4RecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, "mask_prob": 0.15, - "get_val_mask_func": None, + "get_val_mask_func": leave_one_out_mask, } dataset = DATASET From a7546d737add67fd0d0ad91b9a52316cea66d168 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 15:47:03 +0300 Subject: [PATCH 14/25] linters --- tests/models/nn/test_bertrec.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bertrec.py index 79abdc98..6ce4e9a7 100644 --- a/tests/models/nn/test_bertrec.py +++ b/tests/models/nn/test_bertrec.py @@ -32,18 +32,16 @@ from tests.models.data import DATASET from tests.models.utils import assert_default_config_and_default_model_params_are_the_same -# TODO: add tests with BCE and GBCE (they can be broken for the model when softmax is ok => we need happy path test) - def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: rank = ( - interactions - .sort_values(Columns.Datetime, ascending=False, kind="stable") + interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") .groupby(Columns.User, sort=False) .cumcount() ) return rank == 0 + class TestBERT4RecModelConfiguration: def setup_method(self) -> None: self._seed_everything() @@ -53,7 +51,7 @@ def _seed_everything(self) -> None: seed_everything(32, workers=True) def test_from_config(self) -> None: - + config = { "n_blocks": 2, "n_heads": 4, @@ -117,7 +115,7 @@ def test_from_config(self) -> None: assert model.data_preparator_type == BERT4RecDataPreparator assert model.lightning_module_type == SessionEncoderLightningModule assert model.mask_prob == 0.15 - assert model.get_val_mask_func == leave_one_out_mask # is None + assert model.get_val_mask_func is leave_one_out_mask @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config(self, simple_types: bool) -> None: @@ -202,10 +200,8 @@ def test_get_config(self, simple_types: bool) -> None: ), "mask_prob": 0.15, "get_val_mask_func": ( - "tests.models.nn.test_bertrec.leave_one_out_mask" - if simple_types - else leave_one_out_mask - ) + "tests.models.nn.test_bertrec.leave_one_out_mask" if simple_types else leave_one_out_mask + ), } assert config == expected @@ -266,6 +262,3 @@ def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} model = BERT4RecModel() assert_default_config_and_default_model_params_are_the_same(model, default_config) - - -# TODO: test with passed custom callable as `get_val_mask_func` From 268a91a4333d0bcb7df93e00bfbf097d0812ed6c Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 15:52:11 +0300 Subject: [PATCH 15/25] removed TODO --- rectools/models/nn/bert4rec.py | 4 ++-- rectools/models/nn/item_net.py | 2 -- rectools/models/nn/transformer_base.py | 1 - rectools/models/nn/transformer_data_preparator.py | 4 ++-- tests/models/nn/test_sasrec.py | 3 --- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 8b926e17..43c7e59d 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -314,9 +314,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals get_val_mask_func=get_val_mask_func, ) - def _init_data_preparator(self) -> None: # TODO: negative losses are not working now + def _init_data_preparator(self) -> None: self.data_preparator: SessionEncoderDataPreparatorBase = self.data_preparator_type( - session_max_len=self.session_max_len, # -1 + session_max_len=self.session_max_len, n_negatives=self.n_negatives if self.loss != "softmax" else None, batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers, diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 2f87bcd6..fd6af123 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -88,7 +88,6 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: torch.Tensor Item embeddings. """ - # TODO: Should we use torch.nn.EmbeddingBag? feature_dense = self.get_dense_item_features(items) feature_embs = self.category_embeddings(self.feature_catalog.to(self.device)) @@ -252,7 +251,6 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: Item embeddings. """ item_embs = [] - # TODO: Add functionality for parallel computing. for idx_block in range(self.n_item_blocks): item_emb = self.item_net_blocks[idx_block](items) item_embs.append(item_emb) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 0007c292..85aefdde 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -306,7 +306,6 @@ class SessionEncoderLightningModule(SessionEncoderLightningModuleBase): def on_train_start(self) -> None: """Initialize parameters with values from Xavier normal distribution.""" - # TODO: init padding embedding with zeros self._xavier_normal_init() def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: diff --git a/rectools/models/nn/transformer_data_preparator.py b/rectools/models/nn/transformer_data_preparator.py index f7764e1c..43d466a7 100644 --- a/rectools/models/nn/transformer_data_preparator.py +++ b/rectools/models/nn/transformer_data_preparator.py @@ -164,7 +164,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: ) # Construct dataset - # TODO: user features are dropped for now + # User features are dropped for now because model doesn't support them user_id_map = IdMap.from_values(interactions[Columns.User].values) item_id_map = IdMap.from_values(self.item_extra_tokens) item_id_map = item_id_map.add_ids(interactions[Columns.Item]) @@ -314,7 +314,7 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset rec_user_id_map = IdMap.from_values(interactions[Columns.User]) # Construct dataset - # TODO: For now features are dropped because model doesn't support them + # For now features are dropped because model doesn't support them on inference n_filtered = len(users) - rec_user_id_map.size if n_filtered > 0: explanation = f"""{n_filtered} target users were considered cold because of missing known items""" diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 00892ce4..df1f39fa 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -39,9 +39,6 @@ ) from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal -# TODO: add tests with BCE and GBCE -# TODO: tests for BERT4Rec in a separate file (one loss will be enough) - class TestSASRecModel: def setup_method(self) -> None: From a6056324ba9e6e6e2ec33c9edc59d29a618a5888 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Mon, 27 Jan 2025 15:58:19 +0300 Subject: [PATCH 16/25] compat and serialization tests --- rectools/compat.py | 12 ++++++++++++ tests/models/test_serialization.py | 13 +------------ tests/test_compat.py | 4 ++++ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/rectools/compat.py b/rectools/compat.py index 24abe1f1..c983fbc2 100644 --- a/rectools/compat.py +++ b/rectools/compat.py @@ -40,6 +40,18 @@ class DSSMModel(RequirementUnavailable): requirement = "torch" +class SASRecModel(RequirementUnavailable): + """Dummy class, which is returned if there are no dependencies required for the model""" + + requirement = "torch" + + +class BERT4RecModel(RequirementUnavailable): + """Dummy class, which is returned if there are no dependencies required for the model""" + + requirement = "torch" + + class ItemToItemAnnRecommender(RequirementUnavailable): """Dummy class, which is returned if there are no dependencies required for the model""" diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 2b212eab..49c55ce2 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -29,7 +29,6 @@ from rectools.metrics import NDCG from rectools.models import ( - BERT4RecModel, DSSMModel, EASEModel, ImplicitALSWrapperModel, @@ -38,7 +37,6 @@ LightFMWrapperModel, PopularInCategoryModel, PopularModel, - SASRecModel, load_model, model_from_config, ) @@ -55,16 +53,7 @@ for cls in get_successors(ModelBase) if (cls.__module__.startswith("rectools.models") and cls not in INTERMEDIATE_MODEL_CLASSES) ) -CONFIGURABLE_MODEL_CLASSES = tuple( - cls - for cls in EXPOSABLE_MODEL_CLASSES - if cls - not in ( - DSSMModel, - SASRecModel, - BERT4RecModel, - ) -) +CONFIGURABLE_MODEL_CLASSES = tuple(cls for cls in EXPOSABLE_MODEL_CLASSES if cls not in (DSSMModel,)) def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: diff --git a/tests/test_compat.py b/tests/test_compat.py index ee128391..4dd53345 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -17,11 +17,13 @@ import pytest from rectools.compat import ( + BERT4RecModel, DSSMModel, ItemToItemAnnRecommender, ItemToItemVisualApp, LightFMWrapperModel, MetricsApp, + SASRecModel, UserToItemAnnRecommender, VisualApp, ) @@ -31,6 +33,8 @@ "model", ( DSSMModel, + SASRecModel, + BERT4RecModel, ItemToItemAnnRecommender, UserToItemAnnRecommender, LightFMWrapperModel, From 582061179e240cd138e4f4799b2d988fc9ad5ebb Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 10:57:53 +0300 Subject: [PATCH 17/25] fixed example --- examples/9_model_configs_and_saving.ipynb | 353 ++++++++++------------ 1 file changed, 152 insertions(+), 201 deletions(-) diff --git a/examples/9_model_configs_and_saving.ipynb b/examples/9_model_configs_and_saving.ipynb index a5848faa..f635bd44 100644 --- a/examples/9_model_configs_and_saving.ipynb +++ b/examples/9_model_configs_and_saving.ipynb @@ -27,18 +27,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/lightfm/_lightfm_fast.py:9: UserWarning: LightFM was compiled without OpenMP support. Only a single thread will be used.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "from datetime import timedelta\n", "import pandas as pd\n", @@ -336,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -345,7 +336,10 @@ "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" + "HPU available: False, using: 0 HPUs\n", + "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", + " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", + " return self.__pydantic_serializer__.to_python(\n" ] }, { @@ -384,18 +378,20 @@ " 'get_val_mask_func': None}" ] }, - "execution_count": 13, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model = SASRecModel.from_config({\n", + "config = {\n", " \"epochs\": 2,\n", " \"n_blocks\": 1,\n", " \"n_heads\": 1,\n", - " \"n_factors\": 64,\n", - "})\n", + " \"n_factors\": 64, \n", + "}\n", + "\n", + "model = SASRecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -410,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -419,10 +415,7 @@ "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", - " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", - " return self.__pydantic_serializer__.to_python(\n" + "HPU available: False, using: 0 HPUs\n" ] }, { @@ -461,7 +454,7 @@ " 'get_val_mask_func': '__main__.leave_one_out_mask'}" ] }, - "execution_count": 15, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -476,10 +469,12 @@ " )\n", " return rank == 0\n", "\n", - "model = SASRecModel.from_config({\n", + "config = {\n", " \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n", " \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n", - "})\n", + "}\n", + "\n", + "model = SASRecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -492,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -501,10 +496,7 @@ "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n", - " Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n", - " return self.__pydantic_serializer__.to_python(\n" + "HPU available: False, using: 0 HPUs\n" ] }, { @@ -538,27 +530,29 @@ " 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n", " 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n", " 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n", - " 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n", + " 'transformer_layers_type': 'rectools.models.nn.transformer_net_blocks.PreLNTransformerLayers',\n", " 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n", " 'get_val_mask_func': '__main__.leave_one_out_mask',\n", " 'mask_prob': 0.2}" ] }, - "execution_count": 17, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model = BERT4RecModel.from_config({\n", + "config = {\n", " \"epochs\": 2,\n", " \"n_blocks\": 1,\n", " \"n_heads\": 1,\n", " \"n_factors\": 64,\n", " \"mask_prob\": 0.2,\n", " \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n", - " \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n", - "})\n", + " \"transformer_layers_type\": \"rectools.models.nn.transformer_base.PreLNTransformerLayers\", # path to transformer layers class\n", + "}\n", + "\n", + "model = BERT4RecModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -582,22 +576,7 @@ }, { "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "model = ImplicitItemKNNWrapperModel.from_config({\n", - " \"model\": {\n", - " \"cls\": \"TFIDFRecommender\", # or \"implicit.nearest_neighbours.TFIDFRecommender\"\n", - " \"K\": 50, \n", - " \"num_threads\": 1\n", - " }\n", - "})" - ] - }, - { - "cell_type": "code", - "execution_count": 39, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -610,12 +589,21 @@ " 'model.num_threads': 1}" ] }, - "execution_count": 39, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"model\": {\n", + " \"cls\": \"TFIDFRecommender\", # or \"implicit.nearest_neighbours.TFIDFRecommender\"\n", + " \"K\": 50, \n", + " \"num_threads\": 1\n", + " } \n", + "}\n", + "\n", + "model = ImplicitItemKNNWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -636,29 +624,17 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"model\": {\n", - " # \"cls\": \"AlternatingLeastSquares\", # will work too\n", - " # \"cls\": \"implicit.als.AlternatingLeastSquares\", # will work too\n", - " \"factors\": 16,\n", - " \"num_threads\": 2,\n", - " \"iterations\": 2,\n", - " \"random_state\": 32\n", - " },\n", - " \"fit_features_together\": True,\n", - "}\n", - "model = ImplicitALSWrapperModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, + "execution_count": 24, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 8 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", + " check_blas_config()\n" + ] + }, { "data": { "text/plain": [ @@ -676,15 +652,30 @@ " 'model.calculate_training_loss': False,\n", " 'model.num_threads': 2,\n", " 'model.random_state': 32,\n", - " 'fit_features_together': True}" + " 'fit_features_together': True,\n", + " 'recommend_n_threads': None,\n", + " 'recommend_use_gpu_ranking': None}" ] }, - "execution_count": 21, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"model\": {\n", + " # \"cls\": \"AlternatingLeastSquares\", # will work too\n", + " # \"cls\": \"implicit.als.AlternatingLeastSquares\", # will work too\n", + " \"factors\": 16,\n", + " \"num_threads\": 2,\n", + " \"iterations\": 2,\n", + " \"random_state\": 32\n", + " },\n", + " \"fit_features_together\": True,\n", + "}\n", + "\n", + "model = ImplicitALSWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -705,27 +696,7 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"model\": {\n", - " # \"cls\": \"BayesianPersonalizedRanking\", # will work too\n", - " # \"cls\": \"implicit.bpr.BayesianPersonalizedRanking\", # will work too\n", - " \"factors\": 16,\n", - " \"num_threads\": 2,\n", - " \"iterations\": 2,\n", - " \"random_state\": 32\n", - " },\n", - " \"recommend_use_gpu_ranking\": False,\n", - "}\n", - "model = ImplicitBPRWrapperModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -737,21 +708,35 @@ " 'model.factors': 16,\n", " 'model.learning_rate': 0.01,\n", " 'model.regularization': 0.01,\n", - " 'model.dtype': 'float64',\n", + " 'model.dtype': 'float32',\n", + " 'model.num_threads': 2,\n", " 'model.iterations': 2,\n", " 'model.verify_negative_samples': True,\n", " 'model.random_state': 32,\n", - " 'model.use_gpu': True,\n", + " 'model.use_gpu': False,\n", " 'recommend_n_threads': None,\n", " 'recommend_use_gpu_ranking': False}" ] }, - "execution_count": 13, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"model\": {\n", + " # \"cls\": \"BayesianPersonalizedRanking\", # will work too\n", + " # \"cls\": \"implicit.bpr.BayesianPersonalizedRanking\", # will work too\n", + " \"factors\": 16,\n", + " \"num_threads\": 2,\n", + " \"iterations\": 2,\n", + " \"random_state\": 32\n", + " },\n", + " \"recommend_use_gpu_ranking\": False,\n", + "}\n", + "\n", + "model = ImplicitBPRWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -764,34 +749,31 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"regularization\": 100,\n", - " \"verbose\": 1,\n", - "}\n", - "model = EASEModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'cls': 'EASEModel', 'verbose': 1, 'regularization': 100.0, 'num_threads': 1}" + "{'cls': 'EASEModel',\n", + " 'verbose': 1,\n", + " 'regularization': 100.0,\n", + " 'recommend_n_threads': 0,\n", + " 'recommend_use_gpu_ranking': True}" ] }, - "execution_count": 23, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"regularization\": 100,\n", + " \"verbose\": 1,\n", + "}\n", + "\n", + "model = EASEModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -804,19 +786,7 @@ }, { "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"factors\": 32,\n", - "}\n", - "model = PureSVDModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -827,15 +797,22 @@ " 'factors': 32,\n", " 'tol': 0.0,\n", " 'maxiter': None,\n", - " 'random_state': None}" + " 'random_state': None,\n", + " 'recommend_n_threads': 0,\n", + " 'recommend_use_gpu_ranking': True}" ] }, - "execution_count": 25, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"factors\": 32,\n", + "}\n", + "\n", + "model = PureSVDModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -857,27 +834,7 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"model\": {\n", - " # \"cls\": \"lightfm.lightfm.LightFM\", # will work too \n", - " # \"cls\": \"LightFM\", # will work too \n", - " \"no_components\": 16,\n", - " \"learning_rate\": 0.03,\n", - " \"random_state\": 32,\n", - " \"loss\": \"warp\"\n", - " },\n", - " \"epochs\": 2,\n", - "}\n", - "model = LightFMWrapperModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -899,15 +856,30 @@ " 'model.max_sampled': 10,\n", " 'model.random_state': 32,\n", " 'epochs': 2,\n", - " 'num_threads': 1}" + " 'num_threads': 1,\n", + " 'recommend_n_threads': None,\n", + " 'recommend_use_gpu_ranking': True}" ] }, - "execution_count": 31, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"model\": {\n", + " # \"cls\": \"lightfm.lightfm.LightFM\", # will work too \n", + " # \"cls\": \"LightFM\", # will work too \n", + " \"no_components\": 16,\n", + " \"learning_rate\": 0.03,\n", + " \"random_state\": 32,\n", + " \"loss\": \"warp\"\n", + " },\n", + " \"epochs\": 2,\n", + "}\n", + "\n", + "model = LightFMWrapperModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -920,21 +892,7 @@ }, { "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [], - "source": [ - "from datetime import timedelta\n", - "config = {\n", - " \"popularity\": \"n_interactions\",\n", - " \"period\": timedelta(weeks=2),\n", - "}\n", - "model = PopularModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -949,12 +907,19 @@ " 'inverse': False}" ] }, - "execution_count": 33, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "from datetime import timedelta\n", + "config = {\n", + " \"popularity\": \"n_interactions\",\n", + " \"period\": timedelta(weeks=2),\n", + "}\n", + "\n", + "model = PopularModel.from_config(config)\n", "model.get_params(simple_types=True)" ] }, @@ -967,22 +932,7 @@ }, { "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"popularity\": \"n_interactions\",\n", - " \"period\": timedelta(days=1),\n", - " \"category_feature\": \"genres\",\n", - " \"mixing_strategy\": \"group\"\n", - "}\n", - "model = PopularInCategoryModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -1001,13 +951,21 @@ " 'ratio_strategy': 'proportional'}" ] }, - "execution_count": 35, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model.get_params(simple_types=True)\n" + "config = {\n", + " \"popularity\": \"n_interactions\",\n", + " \"period\": timedelta(days=1),\n", + " \"category_feature\": \"genres\",\n", + " \"mixing_strategy\": \"group\"\n", + "}\n", + "\n", + "model = PopularInCategoryModel.from_config(config)\n", + "model.get_params(simple_types=True)" ] }, { @@ -1019,19 +977,7 @@ }, { "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [], - "source": [ - "config = {\n", - " \"random_state\": 32,\n", - "}\n", - "model = RandomModel.from_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1040,12 +986,17 @@ "{'cls': 'RandomModel', 'verbose': 0, 'random_state': 32}" ] }, - "execution_count": 37, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "config = {\n", + " \"random_state\": 32,\n", + "}\n", + "\n", + "model = RandomModel.from_config(config)\n", "model.get_params(simple_types=True)" ] } From d6fa3b21065d75de02305038df2b27bfe460886b Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 11:28:23 +0300 Subject: [PATCH 18/25] fixed configs --- rectools/models/nn/bert4rec.py | 37 ++----------------- rectools/models/nn/sasrec.py | 34 ++--------------- rectools/models/nn/transformer_base.py | 8 +--- .../nn/{test_bertrec.py => test_bert4rec.py} | 0 4 files changed, 10 insertions(+), 69 deletions(-) rename tests/models/nn/{test_bertrec.py => test_bert4rec.py} (100%) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 43c7e59d..f1f4c4a6 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -363,36 +363,7 @@ def _get_config(self) -> BERT4RecModelConfig: @classmethod def _from_config(cls, config: BERT4RecModelConfig) -> tpe.Self: - return cls( - trainer=None, - n_blocks=config.n_blocks, - n_heads=config.n_heads, - n_factors=config.n_factors, - use_pos_emb=config.use_pos_emb, - use_causal_attn=config.use_causal_attn, - use_key_padding_mask=config.use_key_padding_mask, - dropout_rate=config.dropout_rate, - session_max_len=config.session_max_len, - dataloader_num_workers=config.dataloader_num_workers, - batch_size=config.batch_size, - loss=config.loss, - n_negatives=config.n_negatives, - gbce_t=config.gbce_t, - lr=config.lr, - epochs=config.epochs, - verbose=config.verbose, - deterministic=config.deterministic, - recommend_devices=config.recommend_devices, - recommend_accelerator=config.recommend_accelerator, - recommend_batch_size=config.recommend_batch_size, - recommend_n_threads=config.recommend_n_threads, - recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, - train_min_user_interactions=config.train_min_user_interactions, - item_net_block_types=config.item_net_block_types, - pos_encoding_type=config.pos_encoding_type, - transformer_layers_type=config.transformer_layers_type, - data_preparator_type=config.data_preparator_type, - lightning_module_type=config.lightning_module_type, - mask_prob=config.mask_prob, - get_val_mask_func=config.get_val_mask_func, - ) + params = config.model_dump() + params.pop("cls") + params["trainer"] = None + return cls(**params) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index e22d0dea..4b0622dd 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -397,33 +397,7 @@ def _get_config(self) -> SASRecModelConfig: @classmethod def _from_config(cls, config: SASRecModelConfig) -> tpe.Self: - return cls( - trainer=None, - n_blocks=config.n_blocks, - n_heads=config.n_heads, - n_factors=config.n_factors, - use_pos_emb=config.use_pos_emb, - use_causal_attn=config.use_causal_attn, - use_key_padding_mask=config.use_key_padding_mask, - dropout_rate=config.dropout_rate, - session_max_len=config.session_max_len, - dataloader_num_workers=config.dataloader_num_workers, - batch_size=config.batch_size, - loss=config.loss, - n_negatives=config.n_negatives, - gbce_t=config.gbce_t, - lr=config.lr, - epochs=config.epochs, - verbose=config.verbose, - deterministic=config.deterministic, - recommend_devices=config.recommend_devices, - recommend_n_threads=config.recommend_n_threads, - recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, - train_min_user_interactions=config.train_min_user_interactions, - item_net_block_types=config.item_net_block_types, - pos_encoding_type=config.pos_encoding_type, - transformer_layers_type=config.transformer_layers_type, - data_preparator_type=config.data_preparator_type, - lightning_module_type=config.lightning_module_type, - get_val_mask_func=config.get_val_mask_func, - ) + params = config.model_dump() + params.pop("cls") + params["trainer"] = None + return cls(**params) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 85aefdde..b119f8b0 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -471,11 +471,11 @@ def _get_class_obj(spec: tp.Any) -> tp.Any: return import_object(spec) -def _get_class_obj_sequence(spec: tp.Sequence[tp.Any]) -> tp.Any: +def _get_class_obj_sequence(spec: tp.Sequence[tp.Any]) -> tp.Tuple[tp.Any, ...]: return tuple(map(_get_class_obj, spec)) -def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: +def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: return tuple(map(get_class_or_function_full_path, obj)) @@ -529,10 +529,6 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Sequence[str]: ), ] -SessionEncoderDataPreparatorType_T = tp.TypeVar( - "SessionEncoderDataPreparatorType_T", bound=SessionEncoderDataPreparatorType -) - CallableSerialized = tpe.Annotated[ tp.Callable, BeforeValidator(_get_class_obj), diff --git a/tests/models/nn/test_bertrec.py b/tests/models/nn/test_bert4rec.py similarity index 100% rename from tests/models/nn/test_bertrec.py rename to tests/models/nn/test_bert4rec.py From 91df626cb035b58aa2534f2a9b97acab78cf3b9a Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 14:05:45 +0300 Subject: [PATCH 19/25] updated tests --- tests/models/nn/test_bert4rec.py | 197 ++++++------------------------- tests/models/nn/test_sasrec.py | 190 +++++++---------------------- tests/models/nn/utils.py | 12 ++ 3 files changed, 96 insertions(+), 303 deletions(-) create mode 100644 tests/models/nn/utils.py diff --git a/tests/models/nn/test_bert4rec.py b/tests/models/nn/test_bert4rec.py index 6ce4e9a7..9c5c82d1 100644 --- a/tests/models/nn/test_bert4rec.py +++ b/tests/models/nn/test_bert4rec.py @@ -20,7 +20,6 @@ import torch from pytorch_lightning import seed_everything -from rectools import Columns from rectools.models import BERT4RecModel from rectools.models.nn.bert4rec import BERT4RecDataPreparator from rectools.models.nn.item_net import IdEmbeddingsItemNet @@ -32,14 +31,7 @@ from tests.models.data import DATASET from tests.models.utils import assert_default_config_and_default_model_params_are_the_same - -def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: - rank = ( - interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") - .groupby(Columns.User, sort=False) - .cumcount() - ) - return rank == 0 +from .utils import leave_one_out_mask class TestBERT4RecModelConfiguration: @@ -50,8 +42,8 @@ def _seed_everything(self) -> None: torch.use_deterministic_algorithms(True) seed_everything(32, workers=True) - def test_from_config(self) -> None: - + @pytest.fixture + def initial_config(self) -> tp.Dict[str, tp.Any]: config = { "n_blocks": 2, "n_heads": 4, @@ -61,7 +53,7 @@ def test_from_config(self) -> None: "use_key_padding_mask": True, "dropout_rate": 0.5, "session_max_len": 10, - "dataloader_num_workers": 5, + "dataloader_num_workers": 0, "batch_size": 1024, "loss": "softmax", "n_negatives": 10, @@ -75,7 +67,7 @@ def test_from_config(self) -> None: "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 5, + "train_min_user_interactions": 2, "item_net_block_types": (IdEmbeddingsItemNet,), "pos_encoding_type": LearnableInversePositionalEncoding, "transformer_layers_type": PreLNTransformerLayers, @@ -84,164 +76,53 @@ def test_from_config(self) -> None: "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, } - model = BERT4RecModel.from_config(config) - assert model.n_blocks == 2 - assert model.n_heads == 4 - assert model.n_factors == 64 - assert model.use_pos_emb is False - assert model.use_causal_attn is False - assert model.use_key_padding_mask is True - assert model.dropout_rate == 0.5 - assert model.session_max_len == 10 - assert model.dataloader_num_workers == 5 - assert model.batch_size == 1024 - assert model.loss == "softmax" - assert model.n_negatives == 10 - assert model.gbce_t == 0.5 - assert model.lr == 0.001 - assert model.epochs == 10 - assert model.verbose == 1 - assert model.deterministic is True - assert model.recommend_accelerator == "auto" - assert model.recommend_devices == 1 - assert model.recommend_batch_size == 256 - assert model.recommend_n_threads == 0 - assert model.recommend_use_gpu_ranking is True - assert model.train_min_user_interactions == 5 + return config + + def test_from_config(self, initial_config: tp.Dict[str, tp.Any]) -> None: + model = BERT4RecModel.from_config(initial_config) + + for key, config_value in initial_config.items(): + assert getattr(model, key) == config_value + assert model._trainer is not None # pylint: disable = protected-access - assert model.item_net_block_types == (IdEmbeddingsItemNet,) - assert model.pos_encoding_type == LearnableInversePositionalEncoding - assert model.transformer_layers_type == PreLNTransformerLayers - assert model.data_preparator_type == BERT4RecDataPreparator - assert model.lightning_module_type == SessionEncoderLightningModule - assert model.mask_prob == 0.15 - assert model.get_val_mask_func is leave_one_out_mask @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config(self, simple_types: bool) -> None: - model = BERT4RecModel( - n_blocks=2, - n_heads=4, - n_factors=64, - use_pos_emb=False, - use_causal_attn=False, - use_key_padding_mask=True, - dropout_rate=0.5, - session_max_len=10, - dataloader_num_workers=5, - batch_size=1024, - loss="softmax", - n_negatives=10, - gbce_t=0.5, - lr=0.001, - epochs=10, - verbose=1, - deterministic=True, - recommend_accelerator="auto", - recommend_devices=1, - recommend_batch_size=256, - recommend_n_threads=0, - recommend_use_gpu_ranking=True, - train_min_user_interactions=5, - item_net_block_types=(IdEmbeddingsItemNet,), - pos_encoding_type=LearnableInversePositionalEncoding, - transformer_layers_type=PreLNTransformerLayers, - data_preparator_type=BERT4RecDataPreparator, - lightning_module_type=SessionEncoderLightningModule, - mask_prob=0.15, - get_val_mask_func=leave_one_out_mask, - ) + def test_get_config(self, simple_types: bool, initial_config: tp.Dict[str, tp.Any]) -> None: + model = BERT4RecModel(**initial_config) config = model.get_config(simple_types=simple_types) - expected = { - "cls": "BERT4RecModel" if simple_types else BERT4RecModel, - "n_blocks": 2, - "n_heads": 4, - "n_factors": 64, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, - "dropout_rate": 0.5, - "session_max_len": 10, - "dataloader_num_workers": 5, - "batch_size": 1024, - "loss": "softmax", - "n_negatives": 10, - "gbce_t": 0.5, - "lr": 0.001, - "epochs": 10, - "verbose": 1, - "deterministic": True, - "recommend_accelerator": "auto", - "recommend_devices": 1, - "recommend_batch_size": 256, - "recommend_n_threads": 0, - "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 5, - "item_net_block_types": ( - ["rectools.models.nn.item_net.IdEmbeddingsItemNet"] if simple_types else (IdEmbeddingsItemNet,) - ), - "pos_encoding_type": ( - "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding" - if simple_types - else LearnableInversePositionalEncoding - ), - "transformer_layers_type": ( - "rectools.models.nn.transformer_net_blocks.PreLNTransformerLayers" - if simple_types - else PreLNTransformerLayers - ), - "data_preparator_type": ( - "rectools.models.nn.bert4rec.BERT4RecDataPreparator" if simple_types else BERT4RecDataPreparator - ), - "lightning_module_type": ( - "rectools.models.nn.transformer_base.SessionEncoderLightningModule" - if simple_types - else SessionEncoderLightningModule - ), - "mask_prob": 0.15, - "get_val_mask_func": ( - "tests.models.nn.test_bertrec.leave_one_out_mask" if simple_types else leave_one_out_mask - ), - } + + expected = initial_config.copy() + expected["cls"] = BERT4RecModel + + if simple_types: + simple_types_params = { + "cls": "BERT4RecModel", + "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], + "pos_encoding_type": "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding", + "transformer_layers_type": "rectools.models.nn.transformer_net_blocks.PreLNTransformerLayers", + "data_preparator_type": "rectools.models.nn.bert4rec.BERT4RecDataPreparator", + "lightning_module_type": "rectools.models.nn.transformer_base.SessionEncoderLightningModule", + "get_val_mask_func": "tests.models.nn.test_bert4rec.leave_one_out_mask", + } + expected.update(simple_types_params) + assert config == expected @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: - initial_config = { + def test_get_config_and_from_config_compatibility( + self, simple_types: bool, initial_config: tp.Dict[str, tp.Any] + ) -> None: + dataset = DATASET + model = BERT4RecModel + updated_params = { "n_blocks": 1, "n_heads": 1, "n_factors": 10, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, - "dropout_rate": 0.5, "session_max_len": 5, - "dataloader_num_workers": 1, - "batch_size": 100, - "loss": "softmax", - "n_negatives": 4, - "gbce_t": 0.5, - "lr": 0.001, "epochs": 1, - "verbose": 0, - "deterministic": True, - "recommend_accelerator": "auto", - "recommend_devices": 1, - "recommend_batch_size": 256, - "recommend_n_threads": 0, - "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 2, - "item_net_block_types": (IdEmbeddingsItemNet,), - "pos_encoding_type": LearnableInversePositionalEncoding, - "transformer_layers_type": PreLNTransformerLayers, - "data_preparator_type": BERT4RecDataPreparator, - "lightning_module_type": SessionEncoderLightningModule, - "mask_prob": 0.15, - "get_val_mask_func": leave_one_out_mask, } - - dataset = DATASET - model = BERT4RecModel + config = initial_config.copy() + config.update(updated_params) def get_reco(model: BERT4RecModel) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index df1f39fa..db54e6a7 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -39,6 +39,8 @@ ) from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal +from .utils import leave_one_out_mask + class TestSASRecModel: def setup_method(self) -> None: @@ -831,19 +833,20 @@ def _seed_everything(self) -> None: torch.use_deterministic_algorithms(True) seed_everything(32, workers=True) - def test_from_config(self) -> None: + @pytest.fixture + def initial_config(self) -> tp.Dict[str, tp.Any]: config = { "n_blocks": 2, "n_heads": 4, "n_factors": 64, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, + "use_pos_emb": True, + "use_causal_attn": True, + "use_key_padding_mask": False, "dropout_rate": 0.5, "session_max_len": 10, - "dataloader_num_workers": 5, + "dataloader_num_workers": 0, "batch_size": 1024, - "loss": "BCE", + "loss": "softmax", "n_negatives": 10, "gbce_t": 0.5, "lr": 0.001, @@ -855,164 +858,61 @@ def test_from_config(self) -> None: "recommend_batch_size": 256, "recommend_n_threads": 0, "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 5, + "train_min_user_interactions": 2, "item_net_block_types": (IdEmbeddingsItemNet,), "pos_encoding_type": LearnableInversePositionalEncoding, "transformer_layers_type": SASRecTransformerLayers, "data_preparator_type": SASRecDataPreparator, "lightning_module_type": SessionEncoderLightningModule, - "get_val_mask_func": None, + "get_val_mask_func": leave_one_out_mask, } - model = SASRecModel.from_config(config) - assert model.n_blocks == 2 - assert model.n_heads == 4 - assert model.n_factors == 64 - assert model.use_pos_emb is False - assert model.use_causal_attn is False - assert model.use_key_padding_mask is True - assert model.dropout_rate == 0.5 - assert model.session_max_len == 10 - assert model.dataloader_num_workers == 5 - assert model.batch_size == 1024 - assert model.loss == "BCE" - assert model.n_negatives == 10 - assert model.gbce_t == 0.5 - assert model.lr == 0.001 - assert model.epochs == 10 - assert model.verbose == 1 - assert model.deterministic is True - assert model.recommend_accelerator == "auto" - assert model.recommend_devices == 1 - assert model.recommend_batch_size == 256 - assert model.recommend_n_threads == 0 - assert model.recommend_use_gpu_ranking is True - assert model.train_min_user_interactions == 5 + return config + + def test_from_config(self, initial_config: tp.Dict[str, tp.Any]) -> None: + model = SASRecModel.from_config(initial_config) + + for key, config_value in initial_config.items(): + assert getattr(model, key) == config_value + assert model._trainer is not None # pylint: disable = protected-access - assert model.item_net_block_types == (IdEmbeddingsItemNet,) - assert model.pos_encoding_type == LearnableInversePositionalEncoding - assert model.transformer_layers_type == SASRecTransformerLayers - assert model.data_preparator_type == SASRecDataPreparator - assert model.lightning_module_type == SessionEncoderLightningModule - assert model.get_val_mask_func is None @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config(self, simple_types: bool) -> None: - model = SASRecModel( - n_blocks=2, - n_heads=4, - n_factors=64, - use_pos_emb=False, - use_causal_attn=False, - use_key_padding_mask=True, - dropout_rate=0.5, - session_max_len=10, - dataloader_num_workers=5, - batch_size=1024, - loss="BCE", - n_negatives=10, - gbce_t=0.5, - lr=0.001, - epochs=10, - verbose=1, - deterministic=True, - recommend_accelerator="auto", - recommend_devices=1, - recommend_batch_size=256, - recommend_n_threads=0, - recommend_use_gpu_ranking=True, - train_min_user_interactions=5, - item_net_block_types=(IdEmbeddingsItemNet,), - pos_encoding_type=LearnableInversePositionalEncoding, - transformer_layers_type=SASRecTransformerLayers, - data_preparator_type=SASRecDataPreparator, - lightning_module_type=SessionEncoderLightningModule, - get_val_mask_func=None, - ) + def test_get_config(self, simple_types: bool, initial_config: tp.Dict[str, tp.Any]) -> None: + model = SASRecModel(**initial_config) config = model.get_config(simple_types=simple_types) - expected = { - "cls": "SASRecModel" if simple_types else SASRecModel, - "n_blocks": 2, - "n_heads": 4, - "n_factors": 64, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, - "dropout_rate": 0.5, - "session_max_len": 10, - "dataloader_num_workers": 5, - "batch_size": 1024, - "loss": "BCE", - "n_negatives": 10, - "gbce_t": 0.5, - "lr": 0.001, - "epochs": 10, - "verbose": 1, - "deterministic": True, - "recommend_accelerator": "auto", - "recommend_devices": 1, - "recommend_batch_size": 256, - "recommend_n_threads": 0, - "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 5, - "item_net_block_types": ( - ["rectools.models.nn.item_net.IdEmbeddingsItemNet"] if simple_types else (IdEmbeddingsItemNet,) - ), - "pos_encoding_type": ( - "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding" - if simple_types - else LearnableInversePositionalEncoding - ), - "transformer_layers_type": ( - "rectools.models.nn.sasrec.SASRecTransformerLayers" if simple_types else SASRecTransformerLayers - ), - "data_preparator_type": ( - "rectools.models.nn.sasrec.SASRecDataPreparator" if simple_types else SASRecDataPreparator - ), - "lightning_module_type": ( - "rectools.models.nn.transformer_base.SessionEncoderLightningModule" - if simple_types - else SessionEncoderLightningModule - ), - "get_val_mask_func": None, - } + + expected = initial_config.copy() + expected["cls"] = SASRecModel + + if simple_types: + simple_types_params = { + "cls": "SASRecModel", + "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], + "pos_encoding_type": "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding", + "transformer_layers_type": "rectools.models.nn.transformer_net_blocks.SASRecTransformerLayers", + "data_preparator_type": "rectools.models.nn.bert4rec.SASRecDataPreparator", + "lightning_module_type": "rectools.models.nn.transformer_base.SessionEncoderLightningModule", + "get_val_mask_func": "tests.models.nn.test_bert4rec.leave_one_out_mask", + } + expected.update(simple_types_params) + assert config == expected @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: - initial_config = { + def test_get_config_and_from_config_compatibility( + self, simple_types: bool, initial_config: tp.Dict[str, tp.Any] + ) -> None: + dataset = DATASET + model = SASRecModel + updated_params = { "n_blocks": 1, "n_heads": 1, "n_factors": 10, - "use_pos_emb": False, - "use_causal_attn": False, - "use_key_padding_mask": True, - "dropout_rate": 0.5, "session_max_len": 5, - "dataloader_num_workers": 1, - "batch_size": 100, - "loss": "BCE", - "n_negatives": 4, - "gbce_t": 0.5, - "lr": 0.001, "epochs": 1, - "verbose": 0, - "deterministic": True, - "recommend_accelerator": "auto", - "recommend_devices": 1, - "recommend_batch_size": 256, - "recommend_n_threads": 0, - "recommend_use_gpu_ranking": True, - "train_min_user_interactions": 2, - "item_net_block_types": (IdEmbeddingsItemNet,), - "pos_encoding_type": LearnableInversePositionalEncoding, - "transformer_layers_type": SASRecTransformerLayers, - "data_preparator_type": SASRecDataPreparator, - "lightning_module_type": SessionEncoderLightningModule, - "get_val_mask_func": None, } - - dataset = DATASET - model = SASRecModel + config = initial_config.copy() + config.update(updated_params) def get_reco(model: SASRecModel) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) diff --git a/tests/models/nn/utils.py b/tests/models/nn/utils.py new file mode 100644 index 00000000..a941e7c0 --- /dev/null +++ b/tests/models/nn/utils.py @@ -0,0 +1,12 @@ +import pandas as pd + +from rectools import Columns + + +def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: + rank = ( + interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") + .groupby(Columns.User, sort=False) + .cumcount() + ) + return rank == 0 From 1ea36d8eb6573805c38a39de46f4a55b93862146 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 14:07:51 +0300 Subject: [PATCH 20/25] copyright --- tests/models/nn/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/models/nn/utils.py b/tests/models/nn/utils.py index a941e7c0..7a74aebb 100644 --- a/tests/models/nn/utils.py +++ b/tests/models/nn/utils.py @@ -1,3 +1,17 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pandas as pd from rectools import Columns From 871cb0939041ef35fb9470f02e5dd41761925870 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 14:15:48 +0300 Subject: [PATCH 21/25] fixed accelerator test --- rectools/models/nn/transformer_base.py | 4 ++-- tests/models/nn/test_sasrec.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index b119f8b0..9c868092 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -708,7 +708,7 @@ def _init_torch_model(self) -> None: pos_encoding_type=self.pos_encoding_type, ) - def _init_lightning_model(self, torch_model: TransformerBasedSessionEncoder, n_item_extra_tokens: int) -> None: + def _init_lightning_model(self, torch_model: TransformerBasedSessionEncoder) -> None: self.lightning_model = self.lightning_module_type( torch_model=torch_model, lr=self.lr, @@ -731,7 +731,7 @@ def _fit( torch_model = deepcopy(self._torch_model) torch_model.construct_item_net(self.data_preparator.train_dataset) - self._init_lightning_model(torch_model, self.data_preparator.n_item_extra_tokens) + self._init_lightning_model(torch_model) self.fit_trainer = deepcopy(self._trainer) self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader) diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index db54e6a7..028b26e1 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -105,7 +105,7 @@ def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> pd.Serie return get_val_mask_func @pytest.mark.parametrize( - "accelerator,n_devices,recommend_device", + "accelerator,devices,recommend_accelerator", [ ("cpu", 1, "cpu"), pytest.param( @@ -196,8 +196,8 @@ def test_u2i( dataset: Dataset, filter_viewed: bool, accelerator: str, - n_devices: int, - recommend_device: str, + devices: tp.Union[int, tp.List[int]], + recommend_accelerator: str, expected_cpu_1: pd.DataFrame, expected_cpu_2: pd.DataFrame, expected_gpu: pd.DataFrame, @@ -206,7 +206,7 @@ def test_u2i( max_epochs=2, min_epochs=2, deterministic=True, - devices=n_devices, + devices=devices, accelerator=accelerator, ) model = SASRecModel( @@ -217,16 +217,16 @@ def test_u2i( batch_size=4, epochs=2, deterministic=True, - recommend_accelerator=recommend_device, + recommend_accelerator=recommend_accelerator, item_net_block_types=(IdEmbeddingsItemNet,), trainer=trainer, ) model.fit(dataset=dataset) users = np.array([10, 30, 40]) actual = model.recommend(users=users, dataset=dataset, k=3, filter_viewed=filter_viewed) - if accelerator == "cpu" and n_devices == 1: + if accelerator == "cpu" and devices == 1: expected = expected_cpu_1 - elif accelerator == "cpu" and n_devices == 2: + elif accelerator == "cpu" and devices == 2: expected = expected_cpu_2 else: expected = expected_gpu From aadadda7071582a8d40f1a0c4c3e32fc3f7c4877 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 14:25:23 +0300 Subject: [PATCH 22/25] fixed tests --- tests/models/nn/test_bert4rec.py | 2 +- tests/models/nn/test_sasrec.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/nn/test_bert4rec.py b/tests/models/nn/test_bert4rec.py index 9c5c82d1..4bd891ef 100644 --- a/tests/models/nn/test_bert4rec.py +++ b/tests/models/nn/test_bert4rec.py @@ -102,7 +102,7 @@ def test_get_config(self, simple_types: bool, initial_config: tp.Dict[str, tp.An "transformer_layers_type": "rectools.models.nn.transformer_net_blocks.PreLNTransformerLayers", "data_preparator_type": "rectools.models.nn.bert4rec.BERT4RecDataPreparator", "lightning_module_type": "rectools.models.nn.transformer_base.SessionEncoderLightningModule", - "get_val_mask_func": "tests.models.nn.test_bert4rec.leave_one_out_mask", + "get_val_mask_func": "tests.models.nn.utils.leave_one_out_mask", } expected.update(simple_types_params) diff --git a/tests/models/nn/test_sasrec.py b/tests/models/nn/test_sasrec.py index 028b26e1..d1d56b8f 100644 --- a/tests/models/nn/test_sasrec.py +++ b/tests/models/nn/test_sasrec.py @@ -889,10 +889,10 @@ def test_get_config(self, simple_types: bool, initial_config: tp.Dict[str, tp.An "cls": "SASRecModel", "item_net_block_types": ["rectools.models.nn.item_net.IdEmbeddingsItemNet"], "pos_encoding_type": "rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding", - "transformer_layers_type": "rectools.models.nn.transformer_net_blocks.SASRecTransformerLayers", - "data_preparator_type": "rectools.models.nn.bert4rec.SASRecDataPreparator", + "transformer_layers_type": "rectools.models.nn.sasrec.SASRecTransformerLayers", + "data_preparator_type": "rectools.models.nn.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformer_base.SessionEncoderLightningModule", - "get_val_mask_func": "tests.models.nn.test_bert4rec.leave_one_out_mask", + "get_val_mask_func": "tests.models.nn.utils.leave_one_out_mask", } expected.update(simple_types_params) From 8e7cadece070ec79106a346d94a36e465b49190a Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 14:30:41 +0300 Subject: [PATCH 23/25] moved _from_config to transformer base --- rectools/models/nn/bert4rec.py | 8 -------- rectools/models/nn/sasrec.py | 8 -------- rectools/models/nn/transformer_base.py | 7 +++++++ 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index f1f4c4a6..9c49eb91 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -18,7 +18,6 @@ import numpy as np import torch -import typing_extensions as tpe from pytorch_lightning import Trainer from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase @@ -360,10 +359,3 @@ def _get_config(self) -> BERT4RecModelConfig: mask_prob=self.mask_prob, get_val_mask_func=self.get_val_mask_func, ) - - @classmethod - def _from_config(cls, config: BERT4RecModelConfig) -> tpe.Self: - params = config.model_dump() - params.pop("cls") - params["trainer"] = None - return cls(**params) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 4b0622dd..71d8957d 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -17,7 +17,6 @@ import numpy as np import torch -import typing_extensions as tpe from pytorch_lightning import Trainer from torch import nn @@ -394,10 +393,3 @@ def _get_config(self) -> SASRecModelConfig: lightning_module_type=self.lightning_module_type, get_val_mask_func=self.get_val_mask_func, ) - - @classmethod - def _from_config(cls, config: SASRecModelConfig) -> tpe.Self: - params = config.model_dump() - params.pop("cls") - params["trainer"] = None - return cls(**params) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index 9c868092..dc8645e6 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -842,3 +842,10 @@ def _recommend_i2i( def torch_model(self) -> TransformerBasedSessionEncoder: """Pytorch model.""" return self.lightning_model.torch_model + + @classmethod + def _from_config(cls, config: TransformerModelConfig_T) -> tpe.Self: + params = config.model_dump() + params.pop("cls") + params["trainer"] = None + return cls(**params) From 908eb6d1c0435854662d23bac2c31afae13d08b5 Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 15:48:20 +0300 Subject: [PATCH 24/25] simplified _get_config --- rectools/models/nn/bert4rec.py | 35 -------------------------- rectools/models/nn/sasrec.py | 34 ------------------------- rectools/models/nn/transformer_base.py | 6 +++++ 3 files changed, 6 insertions(+), 69 deletions(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 9c49eb91..019f33a2 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -324,38 +324,3 @@ def _init_data_preparator(self) -> None: mask_prob=self.mask_prob, get_val_mask_func=self.get_val_mask_func, ) - - def _get_config(self) -> BERT4RecModelConfig: - return BERT4RecModelConfig( - cls=self.__class__, - n_blocks=self.n_blocks, - n_heads=self.n_heads, - n_factors=self.n_factors, - use_pos_emb=self.use_pos_emb, - use_causal_attn=self.use_causal_attn, - use_key_padding_mask=self.use_key_padding_mask, - dropout_rate=self.dropout_rate, - session_max_len=self.session_max_len, - dataloader_num_workers=self.dataloader_num_workers, - batch_size=self.batch_size, - loss=self.loss, - n_negatives=self.n_negatives, - gbce_t=self.gbce_t, - lr=self.lr, - epochs=self.epochs, - verbose=self.verbose, - deterministic=self.deterministic, - recommend_devices=self.recommend_devices, - recommend_accelerator=self.recommend_accelerator, - recommend_batch_size=self.recommend_batch_size, - recommend_n_threads=self.recommend_n_threads, - recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, - train_min_user_interactions=self.train_min_user_interactions, - item_net_block_types=self.item_net_block_types, - pos_encoding_type=self.pos_encoding_type, - transformer_layers_type=self.transformer_layers_type, - data_preparator_type=self.data_preparator_type, - lightning_module_type=self.lightning_module_type, - mask_prob=self.mask_prob, - get_val_mask_func=self.get_val_mask_func, - ) diff --git a/rectools/models/nn/sasrec.py b/rectools/models/nn/sasrec.py index 71d8957d..4f347b9b 100644 --- a/rectools/models/nn/sasrec.py +++ b/rectools/models/nn/sasrec.py @@ -359,37 +359,3 @@ def _init_data_preparator(self) -> None: train_min_user_interactions=self.train_min_user_interactions, get_val_mask_func=self.get_val_mask_func, ) - - def _get_config(self) -> SASRecModelConfig: - return SASRecModelConfig( - cls=self.__class__, - n_blocks=self.n_blocks, - n_heads=self.n_heads, - n_factors=self.n_factors, - use_pos_emb=self.use_pos_emb, - use_causal_attn=self.use_causal_attn, - use_key_padding_mask=self.use_key_padding_mask, - dropout_rate=self.dropout_rate, - session_max_len=self.session_max_len, - dataloader_num_workers=self.dataloader_num_workers, - batch_size=self.batch_size, - loss=self.loss, - n_negatives=self.n_negatives, - gbce_t=self.gbce_t, - lr=self.lr, - epochs=self.epochs, - verbose=self.verbose, - deterministic=self.deterministic, - recommend_devices=self.recommend_devices, - recommend_accelerator=self.recommend_accelerator, - recommend_batch_size=self.recommend_batch_size, - recommend_n_threads=self.recommend_n_threads, - recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, - train_min_user_interactions=self.train_min_user_interactions, - item_net_block_types=self.item_net_block_types, - pos_encoding_type=self.pos_encoding_type, - transformer_layers_type=self.transformer_layers_type, - data_preparator_type=self.data_preparator_type, - lightning_module_type=self.lightning_module_type, - get_val_mask_func=self.get_val_mask_func, - ) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index dc8645e6..ed548f6b 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -849,3 +849,9 @@ def _from_config(cls, config: TransformerModelConfig_T) -> tpe.Self: params.pop("cls") params["trainer"] = None return cls(**params) + + def _get_config(self) -> TransformerModelConfig_T: + attrs = self.config_class.model_json_schema(mode="serialization")["properties"].keys() + params = {attr: getattr(self, attr) for attr in attrs if attr != "cls"} + params["cls"] = self.__class__ + return self.config_class(**params) From 3b856d042774ea8d875dd33e0aebd7ec9a4e402a Mon Sep 17 00:00:00 2001 From: Daria Tikhonovich Date: Tue, 28 Jan 2025 20:13:04 +0300 Subject: [PATCH 25/25] -1 in bert data preparator --- rectools/models/nn/bert4rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/models/nn/bert4rec.py b/rectools/models/nn/bert4rec.py index 019f33a2..004e6897 100644 --- a/rectools/models/nn/bert4rec.py +++ b/rectools/models/nn/bert4rec.py @@ -315,7 +315,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals def _init_data_preparator(self) -> None: self.data_preparator: SessionEncoderDataPreparatorBase = self.data_preparator_type( - session_max_len=self.session_max_len, + session_max_len=self.session_max_len - 1, # TODO: remove `-1` n_negatives=self.n_negatives if self.loss != "softmax" else None, batch_size=self.batch_size, dataloader_num_workers=self.dataloader_num_workers,