Skip to content

Commit

Permalink
Merge branch 'experimental/sasrec' into feature/transformer_tutorial_…
Browse files Browse the repository at this point in the history
…update
  • Loading branch information
spirinamayya committed Feb 11, 2025
2 parents 5d368df + 1283bda commit f506ce4
Show file tree
Hide file tree
Showing 30 changed files with 1,790 additions and 554 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased

### Added
- `SASRecModel` and `BERT4RecModel` - models based on transformer architecture ([#220](https://github.com/MobileTeleSystems/RecTools/pull/220))
- Transfomers extended theory & practice tuotorial ([#220](https://github.com/MobileTeleSystems/RecTools/pull/220))
- Transfomers advanced training guide ([#220](https://github.com/MobileTeleSystems/RecTools/pull/220))
- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252))
- `TorchRanker` ranker which calculates scores using torch. Supports GPU. [#251](https://github.com/MobileTeleSystems/RecTools/pull/251)
- `Ranker` ranker protocol which unify rankers call. [#251](https://github.com/MobileTeleSystems/RecTools/pull/251)

### Changed

- `ImplicitRanker` `rank` method compatible with `Ranker` protocol. `use_gpu` and `num_threads` params moved from `rank` method to `__init__`. [#251](https://github.com/MobileTeleSystems/RecTools/pull/251)

## [0.10.0] - 16.01.2025

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ faster than ever before.
- Fully compatible with our `fit` / `recommend` paradigm and require NO special data processing
- Explicitly described in our [Transformers Theory & Practice Tutorial](examples/tutorials/transformers_tutorial.ipynb): loss options, item embedding options, category features utilization and more!
- Configurable, customizable, callback-friendly, checkpoints-included, logs-out-of-the-box, custom-validation-ready, multi-gpu-compatible! See our [Transformers Advanced Training User Guide](examples/tutorials/transformers_advanced_training_guide.ipynb)
- We are running benchmarks with comparison of RecTools models to other open-source implementations following BERT4Rec reproducibility paper and achieve highest scores on multiple datasets: [Performance on public transformers benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results)
- Public benchmarks which compare of RecTools models to other open-source implementations following BERT4Rec replicability paper show that RecTools implementations achieve highest scores on multiple datasets: [Performance on public transformers benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results)



Expand Down Expand Up @@ -214,6 +214,7 @@ make clean
- [Maya Spirina](https://github.com/spirinamayya)
- [Grigoriy Gusarov](https://github.com/Gooogr)
- [Aki Ariga](https://github.com/chezou)
- [Nikolay Undalov](https://github.com/nsundalov)

Previous contributors: [Ildar Safilo](https://github.com/irsafilo) [ex-Maintainer], [Daniil Potapov](https://github.com/sharthZ23) [ex-Maintainer], [Alexander Butenko](https://github.com/iomallach), [Igor Belkov](https://github.com/OzmundSedler), [Artem Senin](https://github.com/artemseninhse), [Mikhail Khasykov](https://github.com/mkhasykov), [Julia Karamnova](https://github.com/JuliaKup), [Maxim Lukin](https://github.com/groundmax), [Yuri Ulianov](https://github.com/yukeeul), [Egor Kratkov](https://github.com/jegorus), [Azat Sibagatulin](https://github.com/azatnv), [Vadim Vetrov](https://github.com/Waujito)

3 changes: 2 additions & 1 deletion rectools/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _serialize_feature_name(spec: tp.Any) -> Hashable:
return tuple(_serialize_feature_name(item) for item in spec)
if isinstance(spec, (int, float, str, bool)):
return spec
if np.issubdtype(spec, np.number) or np.issubdtype(spec, np.bool_): # str is handled by isinstance(spec, str)
if hasattr(spec, "dtype") and (np.issubdtype(spec.dtype, np.number) or np.issubdtype(spec.dtype, np.bool_)):
# numpy str is handled by isinstance(spec, str)
return spec.item()
raise type_error

Expand Down
6 changes: 3 additions & 3 deletions rectools/models/ease.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
recommend_use_gpu_ranking: bool = True,
verbose: int = 0,
):

super().__init__(verbose=verbose)
self.weight: np.ndarray
self.regularization = regularization
Expand Down Expand Up @@ -146,16 +145,17 @@ def _recommend_u2i(
distance=Distance.DOT,
subjects_factors=user_items,
objects_factors=self.weight,
use_gpu=self.recommend_use_gpu_ranking and HAS_CUDA,
num_threads=self.recommend_n_threads,
)

ui_csr_for_filter = user_items[user_ids] if filter_viewed else None

all_user_ids, all_reco_ids, all_scores = ranker.rank(
subject_ids=user_ids,
k=k,
filter_pairs_csr=ui_csr_for_filter,
sorted_object_whitelist=sorted_item_ids_to_recommend,
num_threads=self.recommend_n_threads,
use_gpu=self.recommend_use_gpu_ranking and HAS_CUDA,
)

return all_user_ids, all_reco_ids, all_scores
Expand Down
16 changes: 9 additions & 7 deletions rectools/models/nn/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,16 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_device` attribute.
recommend_use_torch_ranking : bool, default ``True``
Use `TorchRanker` for items ranking while preparing recommendations.
If set to ``False``, use `ImplicitRanker` instead.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_use_torch_ranking` attribute.
recommend_n_threads : int, default 0
Number of threads to use in ranker if GPU ranking is turned off or unavailable.
Number of threads to use for `ImplicitRanker`. Omitted if `recommend_use_torch_ranking` is
set to ``True`` (default).
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_n_threads` attribute.
recommend_use_gpu_ranking : bool, default ``True``
If ``True`` and HAS_CUDA ``True``, set use_gpu=True in ImplicitRanker.rank.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_use_gpu_ranking` attribute.
"""

config_class = BERT4RecModelConfig
Expand Down Expand Up @@ -311,8 +313,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
recommend_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
):
self.mask_prob = mask_prob

Expand All @@ -339,7 +341,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
recommend_batch_size=recommend_batch_size,
recommend_device=recommend_device,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
recommend_use_torch_ranking=recommend_use_torch_ranking,
train_min_user_interactions=train_min_user_interactions,
item_net_block_types=item_net_block_types,
item_net_constructor_type=item_net_constructor_type,
Expand Down
10 changes: 5 additions & 5 deletions rectools/models/nn/item_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:

def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
"""Get categorical item features and offsets for `items`."""
length_range = torch.arange(self.input_lengths.max().item(), device=self.device)
item_indexes = self.offsets[items].unsqueeze(-1) + length_range
length_mask = length_range < self.input_lengths[items].unsqueeze(-1)
item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]]
length_range = torch.arange(self.get_buffer("input_lengths").max().item(), device=self.device)
item_indexes = self.get_buffer("offsets")[items].unsqueeze(-1) + length_range
length_mask = length_range < self.get_buffer("input_lengths")[items].unsqueeze(-1)
item_emb_bag_inputs = self.get_buffer("emb_bag_inputs")[item_indexes[length_mask]]
item_offsets = torch.cat(
(torch.tensor([0], device=self.device), torch.cumsum(self.input_lengths[items], dim=0)[:-1])
(torch.tensor([0], device=self.device), torch.cumsum(self.get_buffer("input_lengths")[items], dim=0)[:-1])
)
return item_emb_bag_inputs, item_offsets

Expand Down
125 changes: 86 additions & 39 deletions rectools/models/nn/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,73 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D
return {"x": torch.LongTensor(x)}


class SASRecTransformerLayer(nn.Module):
"""
Exactly SASRec author's transformer block architecture but with pytorch Multi-Head Attention realisation.
Parameters
----------
n_factors : int
Latent embeddings size.
n_heads : int
Number of attention heads.
dropout_rate : float
Probability of a hidden unit to be zeroed.
"""

def __init__(
self,
n_factors: int,
n_heads: int,
dropout_rate: float,
):
super().__init__()
# important: original architecture had another version of MHA
self.multi_head_attn = torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True)
self.q_layer_norm = nn.LayerNorm(n_factors)
self.ff_layer_norm = nn.LayerNorm(n_factors)
self.feed_forward = PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU())
self.dropout = torch.nn.Dropout(dropout_rate)

def forward(
self,
seqs: torch.Tensor,
attn_mask: tp.Optional[torch.Tensor],
key_padding_mask: tp.Optional[torch.Tensor],
) -> torch.Tensor:
"""
Forward pass through transformer block.
Parameters
----------
seqs : torch.Tensor
User sequences of item embeddings.
attn_mask : torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `attn_mask`.
key_padding_mask : torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.
Returns
-------
torch.Tensor
User sequences passed through transformer layers.
"""
q = self.q_layer_norm(seqs)
mha_output, _ = self.multi_head_attn(
q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
)
seqs = q + mha_output
ff_input = self.ff_layer_norm(seqs)
seqs = self.feed_forward(ff_input)
seqs = self.dropout(seqs)
seqs += ff_input
return seqs


class SASRecTransformerLayers(TransformerLayersBase):
"""
Exactly SASRec author's transformer blocks architecture but with pytorch Multi-Head Attention realisation.
SASRec transformer blocks.
Parameters
----------
Expand All @@ -137,15 +201,16 @@ def __init__(
):
super().__init__()
self.n_blocks = n_blocks
self.multi_head_attn = nn.ModuleList(
[torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)]
) # important: original architecture had another version of MHA
self.q_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
self.feed_forward = nn.ModuleList(
[PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU()) for _ in range(n_blocks)]
self.transformer_blocks = nn.ModuleList(
[
SASRecTransformerLayer(
n_factors,
n_heads,
dropout_rate,
)
for _ in range(self.n_blocks)
]
)
self.dropout = nn.ModuleList([torch.nn.Dropout(dropout_rate) for _ in range(n_blocks)])
self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8)

def forward(
Expand Down Expand Up @@ -175,21 +240,11 @@ def forward(
torch.Tensor
User sequences passed through transformer layers.
"""
seqs *= timeline_mask # [batch_size, session_max_len, n_factors]
for i in range(self.n_blocks):
q = self.q_layer_norm[i](seqs)
mha_output, _ = self.multi_head_attn[i](
q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
)
seqs = q + mha_output
ff_input = self.ff_layer_norm[i](seqs)
seqs = self.feed_forward[i](ff_input)
seqs = self.dropout[i](seqs)
seqs += ff_input
seqs *= timeline_mask

seqs *= timeline_mask # [batch_size, session_max_len, n_factors]
seqs = self.transformer_blocks[i](seqs, attn_mask, key_padding_mask)
seqs *= timeline_mask
seqs = self.last_layernorm(seqs)

return seqs


Expand Down Expand Up @@ -297,14 +352,16 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_device` attribute.
recommend_use_torch_ranking : bool, default ``True``
Use `TorchRanker` for items ranking while preparing recommendations.
If set to ``False``, use `ImplicitRanker` instead.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_use_torch_ranking` attribute.
recommend_n_threads : int, default 0
Number of threads to use in ranker if GPU ranking is turned off or unavailable.
Number of threads to use for `ImplicitRanker`. Omitted if `recommend_use_torch_ranking` is
set to ``True`` (default).
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_n_threads` attribute.
recommend_use_gpu_ranking : bool, default ``True``
If ``True`` and HAS_CUDA ``True``, set use_gpu=True in ImplicitRanker.rank.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_use_gpu_ranking` attribute.
"""

config_class = SASRecModelConfig
Expand Down Expand Up @@ -339,8 +396,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
get_trainer_func: tp.Optional[TrainerCallable] = None,
recommend_batch_size: int = 256,
recommend_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
):
super().__init__(
transformer_layers_type=transformer_layers_type,
Expand All @@ -365,7 +422,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
recommend_batch_size=recommend_batch_size,
recommend_device=recommend_device,
recommend_n_threads=recommend_n_threads,
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
recommend_use_torch_ranking=recommend_use_torch_ranking,
train_min_user_interactions=train_min_user_interactions,
item_net_block_types=item_net_block_types,
item_net_constructor_type=item_net_constructor_type,
Expand All @@ -374,13 +431,3 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
)

def _init_data_preparator(self) -> None:
self.data_preparator = self.data_preparator_type(
session_max_len=self.session_max_len,
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
train_min_user_interactions=self.train_min_user_interactions,
get_val_mask_func=self.get_val_mask_func,
)
Loading

0 comments on commit f506ce4

Please sign in to comment.