Feature/sasrec configs #248

Merged · 28 commits · Jan 28, 2025
532 changes: 365 additions & 167 deletions examples/9_model_configs_and_saving.ipynb

Large diffs are not rendered by default.
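The updated notebook is not rendered in the diff, but it covers the config workflow this PR wires up for the transformer models. Below is a minimal sketch of that round-trip; it assumes the `get_config` / `from_config` API on the model base class and rectools installed with the `torch` extra, so treat it as illustrative rather than a copy of the notebook.

# Hypothetical usage sketch, not taken from the notebook itself.
from rectools.models import SASRecModel

model = SASRecModel(n_blocks=2, n_heads=4, session_max_len=100)

# Serialize hyperparameters into a config...
config = model.get_config()

# ...and rebuild an identical (untrained) model from it.
same_model = SASRecModel.from_config(config)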

12 changes: 12 additions & 0 deletions rectools/compat.py
@@ -40,6 +40,18 @@ class DSSMModel(RequirementUnavailable):
requirement = "torch"


class SASRecModel(RequirementUnavailable):
"""Dummy class, which is returned if there are no dependencies required for the model"""

requirement = "torch"


class BERT4RecModel(RequirementUnavailable):
"""Dummy class, which is returned if there are no dependencies required for the model"""

requirement = "torch"


class ItemToItemAnnRecommender(RequirementUnavailable):
"""Dummy class, which is returned if there are no dependencies required for the model"""

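For context, the two new stubs follow the existing compat pattern: when the `torch` extra is missing, importing the model name yields a placeholder that fails loudly on use. A minimal sketch of that pattern, assuming `RequirementUnavailable` raises `ImportError` on instantiation (the actual implementation in rectools/compat.py may differ):

class RequirementUnavailable:
    """Base for stubs returned when an optional dependency is not installed."""

    requirement: str  # name of the missing extra, e.g. "torch"

    def __init__(self, *args: object, **kwargs: object) -> None:
        raise ImportError(
            f"Install rectools[{self.requirement}] to use {self.__class__.__name__}"
        )

# With the stub in place, calling SASRecModel() without torch installed
# raises a clear ImportError instead of an opaque attribute failure.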
87 changes: 61 additions & 26 deletions rectools/models/nn/bert4rec.py
@@ -14,19 +14,20 @@

import typing as tp
from collections.abc import Hashable
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import Accelerator

from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
from .transformer_base import (
PADDING_VALUE,
SessionEncoderDataPreparatorType,
SessionEncoderLightningModule,
SessionEncoderLightningModuleBase,
TransformerModelBase,
TransformerModelConfig,
)
from .transformer_data_preparator import SessionEncoderDataPreparatorBase
from .transformer_net_blocks import (
@@ -144,7 +145,15 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D
return {"x": torch.LongTensor(x)}


class BERT4RecModel(TransformerModelBase):
class BERT4RecModelConfig(TransformerModelConfig):
"""BERT4RecModel config."""

data_preparator_type: SessionEncoderDataPreparatorType = BERT4RecDataPreparator
use_key_padding_mask: bool = True
mask_prob: float = 0.15


class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
"""
BERT4Rec model.

Expand Down Expand Up @@ -190,10 +199,21 @@ class BERT4RecModel(TransformerModelBase):
deterministic : bool, default ``False``
If ``True``, set deterministic algorithms for PyTorch operations.
Use `pytorch_lightning.seed_everything` together with this parameter to fix the random state.
recommend_device : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"} or Accelerator, default "auto"
Device for recommend. Used at predict_step of lightning module.
recommend_batch_size : int, default 256
How many samples per batch to load during `recommend`.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_batch_size` attribute.
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
Accelerator type for `recommend`. Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_accelerator` attribute.
recommend_devices : int | List[int], default 1
Devices for `recommend`. Please note that multi-device inference is not supported!
Do not specify more than one device. For ``gpu`` accelerator you can pass which device to
use, e.g. ``[1]``.
Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_devices` attribute.
recommend_n_threads : int, default 0
Number of threads to use in ranker if GPU ranking is turned off or unavailable.
If you want to change this parameter after model is initialized,
@@ -222,38 +242,44 @@ class BERT4RecModel(TransformerModelBase):
Function to get validation mask.
"""

config_class = BERT4RecModelConfig

def __init__( # pylint: disable=too-many-arguments, too-many-locals
self,
n_blocks: int = 1,
n_heads: int = 1,
n_factors: int = 128,
n_blocks: int = 2,
n_heads: int = 4,
n_factors: int = 256,
use_pos_emb: bool = True,
use_causal_attn: bool = False,
use_key_padding_mask: bool = True,
dropout_rate: float = 0.2,
epochs: int = 3,
verbose: int = 0,
deterministic: bool = False,
recommend_device: Union[str, Accelerator] = "auto",
recommend_batch_size: int = 256,
recommend_accelerator: str = "auto",
recommend_devices: tp.Union[int, tp.List[int]] = 1,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True,
session_max_len: int = 32,
session_max_len: int = 100,
n_negatives: int = 1,
batch_size: int = 128,
loss: str = "softmax",
gbce_t: float = 0.2,
lr: float = 0.01,
lr: float = 0.001,
dataloader_num_workers: int = 0,
train_min_user_interactions: int = 2,
mask_prob: float = 0.15,
trainer: tp.Optional[Trainer] = None,
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
data_preparator_type: tp.Type[BERT4RecDataPreparator] = BERT4RecDataPreparator,
data_preparator_type: tp.Type[SessionEncoderDataPreparatorBase] = BERT4RecDataPreparator,
lightning_module_type: tp.Type[SessionEncoderLightningModuleBase] = SessionEncoderLightningModule,
get_val_mask_func: tp.Optional[tp.Callable] = None,
):
self.mask_prob = mask_prob

super().__init__(
transformer_layers_type=transformer_layers_type,
data_preparator_type=data_preparator_type,
@@ -264,28 +290,37 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
use_causal_attn=use_causal_attn,
use_key_padding_mask=use_key_padding_mask,
dropout_rate=dropout_rate,
session_max_len=session_max_len,
dataloader_num_workers=dataloader_num_workers,
batch_size=batch_size,
loss=loss,
n_negatives=n_negatives,
gbce_t=gbce_t,
lr=lr,
epochs=epochs,
verbose=verbose,
deterministic=deterministic,
recommend_device=recommend_device,
recommend_batch_size=recommend_batch_size,
recommend_accelerator=recommend_accelerator,
recommend_devices=recommend_devices,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
loss=loss,
gbce_t=gbce_t,
lr=lr,
session_max_len=session_max_len + 1,
train_min_user_interactions=train_min_user_interactions,
trainer=trainer,
item_net_block_types=item_net_block_types,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
get_val_mask_func=get_val_mask_func,
)
self.data_preparator = data_preparator_type(
session_max_len=session_max_len,
n_negatives=n_negatives if loss != "softmax" else None,
batch_size=batch_size,
dataloader_num_workers=dataloader_num_workers,
train_min_user_interactions=train_min_user_interactions,

def _init_data_preparator(self) -> None:
self.data_preparator: SessionEncoderDataPreparatorBase = self.data_preparator_type(
session_max_len=self.session_max_len - 1, # TODO: remove `-1`
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
train_min_user_interactions=self.train_min_user_interactions,
item_extra_tokens=(PADDING_VALUE, MASKING_VALUE),
mask_prob=mask_prob,
get_val_mask_func=get_val_mask_func,
mask_prob=self.mask_prob,
get_val_mask_func=self.get_val_mask_func,
)
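Two things are happening in this file: hyperparameters that used to be forwarded straight into a locally constructed data preparator are now passed to (and stored by) the base class, and `_init_data_preparator` rebuilds the preparator from `self.*` attributes. That is what makes the model reconstructable from `BERT4RecModelConfig`. A hedged usage sketch, assuming `from_config` accepts a plain dict (the exact accepted types are not shown in this diff):

from rectools.models import BERT4RecModel

# Fields not listed fall back to the config defaults visible above,
# e.g. use_key_padding_mask=True and mask_prob=0.15.
model = BERT4RecModel.from_config(
    {
        "n_blocks": 2,
        "n_heads": 4,
        "n_factors": 256,
        "session_max_len": 100,
    }
)
print(model.mask_prob)  # 0.15 (BERT4Rec-specific field)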
2 changes: 0 additions & 2 deletions rectools/models/nn/item_net.py
@@ -88,7 +88,6 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
torch.Tensor
Item embeddings.
"""
# TODO: Should we use torch.nn.EmbeddingBag?
feature_dense = self.get_dense_item_features(items)

feature_embs = self.category_embeddings(self.feature_catalog.to(self.device))
@@ -252,7 +251,6 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
Item embeddings.
"""
item_embs = []
# TODO: Add functionality for parallel computing.
for idx_block in range(self.n_item_blocks):
item_emb = self.item_net_blocks[idx_block](items)
item_embs.append(item_emb)
73 changes: 54 additions & 19 deletions rectools/models/nn/sasrec.py
@@ -13,20 +13,22 @@
# limitations under the License.

import typing as tp
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import Accelerator
from torch import nn

from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
from .transformer_base import (
PADDING_VALUE,
SessionEncoderDataPreparatorType,
SessionEncoderLightningModule,
SessionEncoderLightningModuleBase,
TransformerLayersType,
TransformerModelBase,
TransformerModelConfig,
)
from .transformer_data_preparator import SessionEncoderDataPreparatorBase
from .transformer_net_blocks import (
@@ -183,7 +185,15 @@ def forward(
return seqs


class SASRecModel(TransformerModelBase):
class SASRecModelConfig(TransformerModelConfig):
"""SASRecModel config."""

data_preparator_type: SessionEncoderDataPreparatorType = SASRecDataPreparator
transformer_layers_type: TransformerLayersType = SASRecTransformerLayers
use_causal_attn: bool = True


class SASRecModel(TransformerModelBase[SASRecModelConfig]):
"""
SASRec model.

@@ -227,8 +237,20 @@ class SASRecModel(TransformerModelBase):
deterministic : bool, default ``False``
If ``True``, set deterministic algorithms for PyTorch operations.
Use `pytorch_lightning.seed_everything` together with this parameter to fix the random state.
recommend_device : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"} or Accelerator, default "auto"
Device for recommend. Used at predict_step of lightning module.
recommend_batch_size : int, default 256
How many samples per batch to load during `recommend`.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_batch_size` attribute.
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
Accelerator type for `recommend`. Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_accelerator` attribute.
recommend_devices : int | List[int], default 1
Devices for `recommend`. Please note that multi-device inference is not supported!
Do not specify more than one device. For ``gpu`` accelerator you can pass which device to
use, e.g. ``[1]``.
Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_devices` attribute.
recommend_n_threads : int, default 0
@@ -259,26 +281,30 @@ class SASRecModel(TransformerModelBase):
Function to get validation mask.
"""

config_class = SASRecModelConfig

def __init__( # pylint: disable=too-many-arguments, too-many-locals
self,
n_blocks: int = 1,
n_heads: int = 1,
n_factors: int = 128,
n_blocks: int = 2,
n_heads: int = 4,
n_factors: int = 256,
use_pos_emb: bool = True,
use_causal_attn: bool = True,
use_key_padding_mask: bool = False,
dropout_rate: float = 0.2,
session_max_len: int = 32,
session_max_len: int = 100,
dataloader_num_workers: int = 0,
batch_size: int = 128,
loss: str = "softmax",
n_negatives: int = 1,
gbce_t: float = 0.2,
lr: float = 0.01,
lr: float = 0.001,
epochs: int = 3,
verbose: int = 0,
deterministic: bool = False,
recommend_device: Union[str, Accelerator] = "auto",
recommend_batch_size: int = 256,
recommend_accelerator: str = "auto",
recommend_devices: tp.Union[int, tp.List[int]] = 1,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True,
train_min_user_interactions: int = 2,
@@ -301,26 +327,35 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
use_key_padding_mask=use_key_padding_mask,
dropout_rate=dropout_rate,
session_max_len=session_max_len,
dataloader_num_workers=dataloader_num_workers,
batch_size=batch_size,
loss=loss,
n_negatives=n_negatives,
gbce_t=gbce_t,
lr=lr,
epochs=epochs,
verbose=verbose,
deterministic=deterministic,
recommend_device=recommend_device,
recommend_batch_size=recommend_batch_size,
recommend_accelerator=recommend_accelerator,
recommend_devices=recommend_devices,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
train_min_user_interactions=train_min_user_interactions,
trainer=trainer,
item_net_block_types=item_net_block_types,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
get_val_mask_func=get_val_mask_func,
)
self.data_preparator = data_preparator_type(
session_max_len=session_max_len,
n_negatives=n_negatives if loss != "softmax" else None,
batch_size=batch_size,
dataloader_num_workers=dataloader_num_workers,

def _init_data_preparator(self) -> None:
self.data_preparator = self.data_preparator_type(
session_max_len=self.session_max_len,
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
item_extra_tokens=(PADDING_VALUE,),
train_min_user_interactions=train_min_user_interactions,
get_val_mask_func=get_val_mask_func,
train_min_user_interactions=self.train_min_user_interactions,
get_val_mask_func=self.get_val_mask_func,
)
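The SASRec changes mirror the BERT4Rec ones: `SASRecModelConfig` pins the model-specific defaults (`SASRecDataPreparator`, `SASRecTransformerLayers`, causal attention), and inference control moves to the `recommend_*` parameters. A short sketch of the after-init reassignment the docstrings describe, with values chosen purely for illustration:

from rectools.models import SASRecModel

model = SASRecModel(
    recommend_batch_size=256,
    recommend_accelerator="auto",
    recommend_devices=1,  # multi-device inference is not supported
)

# Per the docstrings, these can be reassigned after initialization:
model.recommend_batch_size = 512
model.recommend_accelerator = "gpu"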