Feature/lightning recommend (#256)
Moved recommend to lightning module: helps introduce custom tying of
user and item embeddings
Introduced item net constructor type
blondered authored Feb 4, 2025
1 parent 823377d commit 0064ace
Showing 9 changed files with 782 additions and 614 deletions.
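The headline API change in this commit is visible in the signatures below: the Lightning-style `recommend_accelerator` / `recommend_devices` pair is replaced by a single `recommend_device` string, and a new `item_net_constructor_type` argument controls how item net block embeddings are aggregated. A rough migration sketch; the mapping of an accelerator plus device index to a `torch.device` string is my assumption, not something this diff spells out:

```python
from rectools.models.nn.item_net import SumOfEmbeddingsConstructor
from rectools.models.nn.sasrec import SASRecModel

# Before this commit (old keyword arguments, per the removed docstring lines):
# model = SASRecModel(recommend_accelerator="gpu", recommend_devices=[1])

# After this commit: one torch.device string; None means "cuda" if available, else "cpu".
model = SASRecModel(
    recommend_device="cuda:1",  # assumed equivalent of accelerator="gpu", devices=[1]
    item_net_constructor_type=SumOfEmbeddingsConstructor,  # new argument; this is the default
)
```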
33 changes: 17 additions & 16 deletions rectools/models/nn/bert4rec.py
@@ -20,7 +20,13 @@
import torch

from .constants import MASKING_VALUE, PADDING_VALUE
from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
from .item_net import (
CatFeaturesItemNet,
IdEmbeddingsItemNet,
ItemNetBase,
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .transformer_base import (
TrainerCallable,
TransformerDataPreparatorType,
@@ -234,6 +240,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
(IdEmbeddingsItemNet,) - item embeddings based on ids.
(CatFeaturesItemNet,) - item embeddings based on categorical features.
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor`
Type of item net blocks aggregation constructor.
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
Type of positional encoding.
transformer_layers_type : type(TransformerLayersBase), default `PreLNTransformerLayers`
@@ -255,16 +263,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
How many samples per batch to load during `recommend`.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_batch_size` attribute.
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
Accelerator type for `recommend`. Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_accelerator` attribute.
recommend_devices : int | List[int], default 1
Devices for `recommend`. Please note that multi-device inference is not supported!
Do not specify more than one device. For ``gpu`` accelerator you can pass which device to
use, e.g. ``[1]``.
Used at predict_step of lightning module.
Multi-device recommendations are not supported.
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
String representation for `torch.device` used for recommendations.
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_device` attribute.
recommend_n_threads : int, default 0
@@ -301,17 +302,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
use_key_padding_mask: bool = True,
use_causal_attn: bool = False,
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
recommend_accelerator: str = "auto",
recommend_devices: tp.Union[int, tp.List[int]] = 1,
recommend_device: tp.Optional[str] = None,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True,
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
):
self.mask_prob = mask_prob

@@ -336,12 +337,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
verbose=verbose,
deterministic=deterministic,
recommend_batch_size=recommend_batch_size,
recommend_accelerator=recommend_accelerator,
recommend_devices=recommend_devices,
recommend_device=recommend_device,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
train_min_user_interactions=train_min_user_interactions,
item_net_block_types=item_net_block_types,
item_net_constructor_type=item_net_constructor_type,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
get_val_mask_func=get_val_mask_func,
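Per the `recommend_device` docstring above, the device can also be re-targeted after construction by assigning the attribute directly; a minimal sketch (all other hyperparameters left at their defaults, the availability check is illustrative):

```python
import torch

from rectools.models.nn.bert4rec import BERT4RecModel

model = BERT4RecModel()  # recommend_device=None: "cuda" if available, otherwise "cpu"

# Re-target inference later, as the docstring describes.
model.recommend_device = "cuda:0" if torch.cuda.is_available() else "cpu"
```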
71 changes: 50 additions & 21 deletions rectools/models/nn/item_net.py
@@ -231,7 +231,7 @@ def from_dataset_schema(cls, dataset_schema: DatasetSchema, n_factors: int, drop
return cls(n_factors, n_items, dropout_rate)


class ItemNetConstructor(ItemNetBase):
class ItemNetConstructorBase(ItemNetBase):
"""
Constructed network for item embeddings based on aggregation of embeddings from transferred item network types.
@@ -257,26 +257,6 @@ def __init__(
self.n_item_blocks = len(item_net_blocks)
self.item_net_blocks = nn.ModuleList(item_net_blocks)

def forward(self, items: torch.Tensor) -> torch.Tensor:
"""
Forward pass to get item embeddings from item network blocks.
Parameters
----------
items : torch.Tensor
Internal item ids.
Returns
-------
torch.Tensor
Item embeddings.
"""
item_embs = []
for idx_block in range(self.n_item_blocks):
item_emb = self.item_net_blocks[idx_block](items)
item_embs.append(item_emb)
return torch.sum(torch.stack(item_embs, dim=0), dim=0)

@property
def catalog(self) -> torch.Tensor:
"""Return tensor with elements in range [0, n_items)."""
@@ -336,3 +316,52 @@ def from_dataset_schema(
item_net_blocks.append(item_net_block)

return cls(n_items, item_net_blocks)

def forward(self, items: torch.Tensor) -> torch.Tensor:
"""Forward pass through item net blocks and aggregation of the results.
Parameters
----------
items : torch.Tensor
Internal item ids.
Returns
-------
torch.Tensor
Item embeddings.
"""
raise NotImplementedError()


class SumOfEmbeddingsConstructor(ItemNetConstructorBase):
"""
Item net blocks constructor that simply sums all of its net block embeddings.
Parameters
----------
n_items : int
Number of items in the dataset.
item_net_blocks : Sequence(ItemNetBase)
Item network blocks whose embeddings are aggregated.
"""

def forward(self, items: torch.Tensor) -> torch.Tensor:
"""
Forward pass through item net blocks and aggregation of the results.
Simple sum of embeddings.
Parameters
----------
items : torch.Tensor
Internal item ids.
Returns
-------
torch.Tensor
Item embeddings.
"""
item_embs = []
for idx_block in range(self.n_item_blocks):
item_emb = self.item_net_blocks[idx_block](items)
item_embs.append(item_emb)
return torch.sum(torch.stack(item_embs, dim=0), dim=0)
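The split between `ItemNetConstructorBase` and `SumOfEmbeddingsConstructor` means a custom aggregation only has to override `forward`. A hypothetical alternative that averages block embeddings instead of summing them (not part of this commit):

```python
import torch

from rectools.models.nn.item_net import ItemNetConstructorBase


class MeanOfEmbeddingsConstructor(ItemNetConstructorBase):
    """Hypothetical constructor averaging item net block embeddings instead of summing them."""

    def forward(self, items: torch.Tensor) -> torch.Tensor:
        # Same block-by-block pass as SumOfEmbeddingsConstructor, aggregated with a mean.
        item_embs = [self.item_net_blocks[idx](items) for idx in range(self.n_item_blocks)]
        return torch.stack(item_embs, dim=0).mean(dim=0)
```

Such a class could then be passed as `item_net_constructor_type=MeanOfEmbeddingsConstructor` to `SASRecModel` or `BERT4RecModel`.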
33 changes: 17 additions & 16 deletions rectools/models/nn/sasrec.py
@@ -19,7 +19,13 @@
import torch
from torch import nn

from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
from .item_net import (
CatFeaturesItemNet,
IdEmbeddingsItemNet,
ItemNetBase,
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .transformer_base import (
TrainerCallable,
TransformerDataPreparatorType,
@@ -263,6 +269,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
(IdEmbeddingsItemNet,) - item embeddings based on ids.
(CatFeaturesItemNet,) - item embeddings based on categorical features.
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor`
Type of item net blocks aggregation constructor.
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
Type of positional encoding.
transformer_layers_type : type(TransformerLayersBase), default `SasRecTransformerLayers`
@@ -284,16 +292,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
How many samples per batch to load during `recommend`.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_batch_size` attribute.
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
Accelerator type for `recommend`. Used at predict_step of lightning module.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_accelerator` attribute.
recommend_devices : int | List[int], default 1
Devices for `recommend`. Please note that multi-device inference is not supported!
Do not specify more than one device. For ``gpu`` accelerator you can pass which device to
use, e.g. ``[1]``.
Used at predict_step of lightning module.
Multi-device recommendations are not supported.
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
String representation for `torch.device` used for recommendations.
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_device` attribute.
recommend_n_threads : int, default 0
@@ -329,17 +330,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
use_key_padding_mask: bool = False,
use_causal_attn: bool = True,
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
recommend_accelerator: str = "auto",
recommend_devices: tp.Union[int, tp.List[int]] = 1,
recommend_device: tp.Optional[str] = None,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True,
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
):
super().__init__(
transformer_layers_type=transformer_layers_type,
@@ -362,12 +363,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
verbose=verbose,
deterministic=deterministic,
recommend_batch_size=recommend_batch_size,
recommend_accelerator=recommend_accelerator,
recommend_devices=recommend_devices,
recommend_device=recommend_device,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
train_min_user_interactions=train_min_user_interactions,
item_net_block_types=item_net_block_types,
item_net_constructor_type=item_net_constructor_type,
pos_encoding_type=pos_encoding_type,
lightning_module_type=lightning_module_type,
get_val_mask_func=get_val_mask_func,
(Diff for the remaining 6 changed files is not shown here.)
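For context, a rough end-to-end sketch of the new arguments in use; the `Columns`, `Dataset.construct`, `fit` and `recommend` calls follow the usual RecTools model API and are not part of this diff, and the data is illustrative:

```python
import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models.nn.sasrec import SASRecModel

# Toy interactions table with the standard RecTools columns.
interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 20, 20, 20],
        Columns.Item: [1, 2, 1, 3, 4],
        Columns.Weight: [1, 1, 1, 1, 1],
        Columns.Datetime: pd.to_datetime(
            ["2021-01-01", "2021-01-02", "2021-01-01", "2021-01-03", "2021-01-04"]
        ),
    }
)
dataset = Dataset.construct(interactions)

# New-style inference settings from this commit.
model = SASRecModel(recommend_device="cpu", recommend_batch_size=128)
model.fit(dataset)

recos = model.recommend(users=[10, 20], dataset=dataset, k=3, filter_viewed=True)
```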
