Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sasrec configs #248

Merged
merged 28 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 248 additions & 1 deletion examples/9_model_configs_and_saving.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,23 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/lightfm/_lightfm_fast.py:9: UserWarning: LightFM was compiled without OpenMP support. Only a single thread will be used.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from datetime import timedelta\n",
"import pandas as pd\n",
"\n",
"from rectools.models import (\n",
" SASRecModel,\n",
" BERT4RecModel,\n",
" ImplicitItemKNNWrapperModel, \n",
" ImplicitALSWrapperModel, \n",
" ImplicitBPRWrapperModel, \n",
Expand Down Expand Up @@ -315,6 +327,241 @@
"## Configs examples for all models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SASRec"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
},
{
"data": {
"text/plain": [
"{'cls': 'SASRecModel',\n",
" 'verbose': 0,\n",
" 'data_preparator_type': 'rectools.models.nn.sasrec.SASRecDataPreparator',\n",
" 'n_blocks': 1,\n",
" 'n_heads': 1,\n",
" 'n_factors': 64,\n",
" 'use_pos_emb': True,\n",
" 'use_causal_attn': True,\n",
" 'use_key_padding_mask': False,\n",
" 'dropout_rate': 0.2,\n",
" 'session_max_len': 100,\n",
" 'dataloader_num_workers': 0,\n",
" 'batch_size': 128,\n",
" 'loss': 'softmax',\n",
" 'n_negatives': 1,\n",
" 'gbce_t': 0.2,\n",
" 'lr': 0.001,\n",
" 'epochs': 2,\n",
" 'deterministic': False,\n",
" 'recommend_batch_size': 256,\n",
" 'recommend_accelerator': 'auto',\n",
" 'recommend_devices': 1,\n",
" 'recommend_n_threads': 0,\n",
" 'recommend_use_gpu_ranking': True,\n",
" 'train_min_user_interactions': 2,\n",
" 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n",
" 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n",
" 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n",
" 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n",
" 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n",
" 'get_val_mask_func': None}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = SASRecModel.from_config({\n",
feldlime marked this conversation as resolved.
Show resolved Hide resolved
" \"epochs\": 2,\n",
" \"n_blocks\": 1,\n",
" \"n_heads\": 1,\n",
" \"n_factors\": 64,\n",
"})\n",
"model.get_params(simple_types=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transformer models (SASRec and BERT4Rec) in RecTools may accept functions and classes as arguments. These types of arguments are fully compatible with RecTools configs. You can either pass them as Python objects or as strings that define their import paths.\n",
"\n",
"Below is an example of both approaches:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n",
" Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n",
" return self.__pydantic_serializer__.to_python(\n"
]
},
{
"data": {
"text/plain": [
"{'cls': 'SASRecModel',\n",
" 'verbose': 0,\n",
" 'data_preparator_type': 'rectools.models.nn.sasrec.SASRecDataPreparator',\n",
" 'n_blocks': 2,\n",
" 'n_heads': 4,\n",
" 'n_factors': 256,\n",
" 'use_pos_emb': True,\n",
" 'use_causal_attn': True,\n",
" 'use_key_padding_mask': False,\n",
" 'dropout_rate': 0.2,\n",
" 'session_max_len': 100,\n",
" 'dataloader_num_workers': 0,\n",
" 'batch_size': 128,\n",
" 'loss': 'softmax',\n",
" 'n_negatives': 1,\n",
" 'gbce_t': 0.2,\n",
" 'lr': 0.001,\n",
" 'epochs': 3,\n",
" 'deterministic': False,\n",
" 'recommend_batch_size': 256,\n",
" 'recommend_accelerator': 'auto',\n",
" 'recommend_devices': 1,\n",
" 'recommend_n_threads': 0,\n",
" 'recommend_use_gpu_ranking': True,\n",
" 'train_min_user_interactions': 2,\n",
" 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n",
" 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n",
" 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n",
" 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n",
" 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n",
" 'get_val_mask_func': '__main__.leave_one_out_mask'}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:\n",
" rank = (\n",
" interactions\n",
" .sort_values(Columns.Datetime, ascending=False, kind=\"stable\")\n",
" .groupby(Columns.User, sort=False)\n",
" .cumcount()\n",
" )\n",
" return rank == 0\n",
"\n",
"model = SASRecModel.from_config({\n",
" \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n",
" \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n",
"})\n",
"model.get_params(simple_types=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### BERT4Rec"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"/Users/dmtikhonov/git_project/metrics/RecTools/.venv/lib/python3.10/site-packages/pydantic/main.py:426: UserWarning: Pydantic serializer warnings:\n",
" Expected `str` but got `tuple` with value `('rectools.models.nn.item...net.CatFeaturesItemNet')` - serialized value may not be as expected\n",
" return self.__pydantic_serializer__.to_python(\n"
]
},
{
"data": {
"text/plain": [
"{'cls': 'BERT4RecModel',\n",
" 'verbose': 0,\n",
" 'data_preparator_type': 'rectools.models.nn.bert4rec.BERT4RecDataPreparator',\n",
" 'n_blocks': 1,\n",
" 'n_heads': 1,\n",
" 'n_factors': 64,\n",
" 'use_pos_emb': True,\n",
" 'use_causal_attn': False,\n",
" 'use_key_padding_mask': True,\n",
" 'dropout_rate': 0.2,\n",
" 'session_max_len': 100,\n",
" 'dataloader_num_workers': 0,\n",
" 'batch_size': 128,\n",
" 'loss': 'softmax',\n",
" 'n_negatives': 1,\n",
" 'gbce_t': 0.2,\n",
" 'lr': 0.001,\n",
" 'epochs': 2,\n",
" 'deterministic': False,\n",
" 'recommend_batch_size': 256,\n",
" 'recommend_accelerator': 'auto',\n",
" 'recommend_devices': 1,\n",
" 'recommend_n_threads': 0,\n",
" 'recommend_use_gpu_ranking': True,\n",
" 'train_min_user_interactions': 2,\n",
" 'item_net_block_types': ['rectools.models.nn.item_net.IdEmbeddingsItemNet',\n",
" 'rectools.models.nn.item_net.CatFeaturesItemNet'],\n",
" 'pos_encoding_type': 'rectools.models.nn.transformer_net_blocks.LearnableInversePositionalEncoding',\n",
" 'transformer_layers_type': 'rectools.models.nn.sasrec.SASRecTransformerLayers',\n",
" 'lightning_module_type': 'rectools.models.nn.transformer_base.SessionEncoderLightningModule',\n",
" 'get_val_mask_func': '__main__.leave_one_out_mask',\n",
" 'mask_prob': 0.2}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = BERT4RecModel.from_config({\n",
" \"epochs\": 2,\n",
" \"n_blocks\": 1,\n",
" \"n_heads\": 1,\n",
" \"n_factors\": 64,\n",
" \"mask_prob\": 0.2,\n",
" \"get_val_mask_func\": leave_one_out_mask, # function to get validation mask\n",
" \"transformer_layers_type\": \"rectools.models.nn.sasrec.SASRecTransformerLayers\", # path to transformer layers class\n",
"})\n",
"model.get_params(simple_types=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
12 changes: 12 additions & 0 deletions rectools/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ class DSSMModel(RequirementUnavailable):
requirement = "torch"


class SASRecModel(RequirementUnavailable):
    """Stand-in for ``SASRecModel``, exposed when the dependencies of the real model are not installed."""

    # Name of the missing requirement associated with this placeholder.
    requirement = "torch"


class BERT4RecModel(RequirementUnavailable):
    """Stand-in for ``BERT4RecModel``, exposed when the dependencies of the real model are not installed."""

    # Name of the missing requirement associated with this placeholder.
    requirement = "torch"


class ItemToItemAnnRecommender(RequirementUnavailable):
"""Dummy class, which is returned if there are no dependencies required for the model"""

Expand Down
Loading