Skip to content

Commit

Permalink
nn modules tests (#247)
Browse files Browse the repository at this point in the history
Added docstrings and tests for nn modules

---------

Co-authored-by: Daria Tikhonovich <daria.m.tikhonovich@gmail.com>
Co-authored-by: blondered <nykenott@gmail.com>
  • Loading branch information
3 people authored Jan 29, 2025
1 parent a062d13 commit 24a7877
Show file tree
Hide file tree
Showing 8 changed files with 1,069 additions and 161 deletions.
42 changes: 27 additions & 15 deletions rectools/models/nn/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
class BERT4RecDataPreparator(SessionEncoderDataPreparatorBase):
"""Data Preparator for BERT4RecModel."""

train_session_max_len_addition: int = 0

def __init__(
self,
session_max_len: int,
Expand Down Expand Up @@ -86,16 +88,22 @@ def _collate_fn_train(
self,
batch: List[Tuple[List[int], List[float]]],
) -> Dict[str, torch.Tensor]:
"""TODO"""
"""
Mask session elements to receive `x`.
Get target by replacing session elements with a MASK token with probability `mask_prob`.
Truncate each session and target from right to keep `session_max_len` last items.
Do left padding until `session_max_len` is reached.
If `n_negatives` is not None, generate negative items from uniform distribution.
"""
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len + 1))
y = np.zeros((batch_size, self.session_max_len + 1))
yw = np.zeros((batch_size, self.session_max_len + 1))
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
for i, (ses, ses_weights) in enumerate(batch):
masked_session, target = self._mask_session(ses)
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len + 1]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len + 1]
yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len + 1]
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len]
yw[i, -len(ses) :] = ses_weights # ses_weights: [session_len] -> yw[i]: [session_max_len]

batch_dict = {"x": torch.LongTensor(x), "y": torch.LongTensor(y), "yw": torch.FloatTensor(yw)}
if self.n_negatives is not None:
Expand All @@ -109,7 +117,7 @@ def _collate_fn_train(

def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len + 1))
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # until only leave-one-strategy
yw = np.zeros((batch_size, 1)) # until only leave-one-strategy
for i, (ses, ses_weights) in enumerate(batch):
Expand All @@ -120,8 +128,8 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
session = session + [self.extra_token_ids[MASKING_VALUE]]
target_idx = [idx for idx, weight in enumerate(ses_weights) if weight != 0][0]

Check warning on line 129 in rectools/models/nn/bert4rec.py

View check run for this annotation

Codecov / codecov/patch

rectools/models/nn/bert4rec.py#L128-L129

Added lines #L128 - L129 were not covered by tests

# ses: [session_len] -> x[i]: [session_max_len + 1]
x[i, -len(input_session) - 1 :] = session[-self.session_max_len - 1 :]
# ses: [session_len] -> x[i]: [session_max_len]
x[i, -len(input_session) - 1 :] = session[-self.session_max_len :]
y[i, -1:] = ses[target_idx] # y[i]: [1]
yw[i, -1:] = ses_weights[target_idx] # yw[i]: [1]

Check warning on line 134 in rectools/models/nn/bert4rec.py

View check run for this annotation

Codecov / codecov/patch

rectools/models/nn/bert4rec.py#L132-L134

Added lines #L132 - L134 were not covered by tests

Expand All @@ -136,12 +144,16 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
return batch_dict

Check warning on line 144 in rectools/models/nn/bert4rec.py

View check run for this annotation

Codecov / codecov/patch

rectools/models/nn/bert4rec.py#L143-L144

Added lines #L143 - L144 were not covered by tests

def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
"""Right truncation, left padding to session_max_len"""
x = np.zeros((len(batch), self.session_max_len + 1))
"""
Right truncation, left padding to `session_max_len`
During inference model will use (`session_max_len` - 1) interactions
and one extra "MASK" token will be added for making predictions.
"""
x = np.zeros((len(batch), self.session_max_len))
for i, (ses, _) in enumerate(batch):
session = ses.copy()
session = session + [self.extra_token_ids[MASKING_VALUE]]
x[i, -len(ses) - 1 :] = session[-self.session_max_len - 1 :]
x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
return {"x": torch.LongTensor(x)}


Expand Down Expand Up @@ -228,7 +240,7 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)`
Type of network returning item embeddings.
(IdEmbeddingsItemNet,) - item embeddings based on ids.
(, CatFeaturesItemNet) - item embeddings based on categorical features.
(CatFeaturesItemNet,) - item embeddings based on categorical features.
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
Type of positional encoding.
Expand Down Expand Up @@ -315,7 +327,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals

def _init_data_preparator(self) -> None:
self.data_preparator: SessionEncoderDataPreparatorBase = self.data_preparator_type(
session_max_len=self.session_max_len - 1, # TODO: remove `-1`
session_max_len=self.session_max_len,
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
Expand Down
14 changes: 12 additions & 2 deletions rectools/models/nn/item_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,18 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:

@classmethod
def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> tpe.Self:
"""TODO"""
"""
Create IdEmbeddingsItemNet from RecTools dataset.
Parameters
----------
dataset : Dataset
RecTools dataset.
n_factors : int
Latent embedding size of item embeddings.
dropout_rate : float
Probability of a hidden unit of item embedding to be zeroed.
"""
n_items = dataset.item_id_map.size
return cls(n_factors, n_items, dropout_rate)

Expand All @@ -226,7 +237,6 @@ def __init__(
n_items: int,
item_net_blocks: tp.Sequence[ItemNetBase],
) -> None:
"""TODO"""
super().__init__()

if len(item_net_blocks) == 0:
Expand Down
8 changes: 5 additions & 3 deletions rectools/models/nn/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@
class SASRecDataPreparator(SessionEncoderDataPreparatorBase):
"""Data preparator for SASRecModel."""

train_session_max_len_addition: int = 1

def _collate_fn_train(
self,
batch: List[Tuple[List[int], List[float]]],
) -> Dict[str, torch.Tensor]:
"""
Truncate each session from right to keep (session_max_len+1) last items.
Do left padding until (session_max_len+1) is reached.
Truncate each session from right to keep `session_max_len` items.
Do left padding until `session_max_len` is reached.
Split to `x`, `y`, and `yw`.
"""
batch_size = len(batch)
Expand Down Expand Up @@ -267,7 +269,7 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
item_net_block_types : sequence of `type(ItemNetBase)`, default `(IdEmbeddingsItemNet, CatFeaturesItemNet)`
Type of network returning item embeddings.
(IdEmbeddingsItemNet,) - item embeddings based on ids.
(, CatFeaturesItemNet) - item embeddings based on categorical features.
(CatFeaturesItemNet,) - item embeddings based on categorical features.
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
Type of positional encoding.
Expand Down
4 changes: 2 additions & 2 deletions rectools/models/nn/transformer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,8 +765,8 @@ def _recommend_u2i(
recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size)

session_embs = recommend_trainer.predict(model=self.lightning_model, dataloaders=recommend_dataloader)
if session_embs is None:
explanation = """Received empty recommendations."""
if session_embs is None: # pragma: no cover
explanation = """Received empty recommendations. Used to solve incompatible type linter error."""
raise ValueError(explanation)
user_embs = np.concatenate(session_embs, axis=0)
user_embs = user_embs[user_ids]
Expand Down
5 changes: 3 additions & 2 deletions rectools/models/nn/transformer_data_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SessionEncoderDataPreparatorBase:
Function to get validation mask.
"""

train_session_max_len_addition: int = 0

def __init__(
self,
session_max_len: int,
Expand All @@ -116,7 +118,6 @@ def __init__(
get_val_mask_func: tp.Optional[tp.Callable] = None,
**kwargs: tp.Any,
) -> None:
"""TODO"""
self.item_id_map: IdMap
self.extra_token_ids: tp.Dict
self.train_dataset: Dataset
Expand Down Expand Up @@ -160,7 +161,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
interactions = (
interactions.sort_values(Columns.Datetime, kind="stable")
.groupby(Columns.User, sort=False)
.tail(self.session_max_len + 1)
.tail(self.session_max_len + self.train_session_max_len_addition)
)

# Construct dataset
Expand Down
Loading

0 comments on commit 24a7877

Please sign in to comment.