diff --git a/examples/bert4rec.ipynb b/examples/bert4rec.ipynb deleted file mode 100644 index 9bd18521..00000000 --- a/examples/bert4rec.ipynb +++ /dev/null @@ -1,587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "import threadpoolctl\n", - "from pathlib import Path\n", - "from lightning_fabric import seed_everything\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from rectools import Columns\n", - "\n", - "\n", - "from rectools.dataset import Dataset\n", - "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", - "from rectools.models import BERT4RecModel\n", - "from rectools.models.nn.item_net import IdEmbeddingsItemNet" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", - "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", - "threadpoolctl.threadpool_limits(1, \"blas\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Prepare data" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# %%time\n", - "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip\n", - "# !unzip -o data_original.zip\n", - "# !rm data_original.zip" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_PATH = Path(\"data_original\")\n", - "\n", - "interactions = (\n", - " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", - " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", - "\n", - "# Split to train / test\n", - "max_date = interactions[Columns.Datetime].max()\n", - "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", - "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", - "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", - "\n", - "# drop items with less than 20 interactions in train\n", - "items = train[\"item_id\"].value_counts()\n", - "items = items[items >= 20]\n", - "items = items.index.to_list()\n", - "train = train[train[\"item_id\"].isin(items)]\n", - " \n", - "# drop users with less than 2 interactions in train\n", - "users = train[\"user_id\"].value_counts()\n", - "users = users[users >= 2]\n", - "users = users.index.to_list()\n", - "train = train[(train[\"user_id\"].isin(users))]\n", - "\n", - "users = train[\"user_id\"].drop_duplicates().to_list()\n", - "\n", - "# drop cold users from test\n", - "test_users_sasrec = test[Columns.User].unique()\n", - "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", - "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", - "test_users = test[Columns.User].unique()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "items = pd.read_csv(DATA_PATH / 'items.csv')" - ] - }, - { - "cell_type": 
"code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Process item features to the form of a flatten dataframe\n", - "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", - "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", - "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", - "genre_feature.columns = [\"id\", \"value\"]\n", - "genre_feature[\"feature\"] = \"genre\"\n", - "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", - "content_feature.columns = [\"id\", \"value\"]\n", - "content_feature[\"feature\"] = \"content_type\"\n", - "item_features = pd.concat((genre_feature, content_feature))\n", - "\n", - "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", - "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", - "\n", - "catalog=train[Columns.Item].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_no_features = Dataset.construct(\n", - " interactions_df=train,\n", - ")\n", - "\n", - "dataset_item_features = Dataset.construct(\n", - " interactions_df=train,\n", - " item_features_df=item_features,\n", - " cat_item_features=[\"genre\", \"content_type\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "metrics_name = {\n", - " 'MAP': MAP,\n", - " 'MIUF': MeanInvUserFreq,\n", - " 'Serendipity': Serendipity\n", - " \n", - "\n", - "}\n", - "metrics = {}\n", - "for metric_name, metric in metrics_name.items():\n", - " for k in (1, 5, 10):\n", - " metrics[f'{metric_name}@{k}'] = metric(k=k)\n", - "\n", - "# list with metrics results of all models\n", - "features_results = []\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# BERT4Rec" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### BERT4Rec with item ids embeddings in ItemNetBlock" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "model = BERT4RecModel(\n", - " n_blocks=3,\n", - " n_heads=4,\n", - " dropout_rate=0.2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " mask_prob=0.5,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. 
Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 1.3 M \n", - "---------------------------------------------------------------\n", - "1.3 M Trainable params\n", - "0 Non-trainable params\n", - "1.3 M Total params\n", - "5.292 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fee32d8cfa0147689a003b073d52fbd5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "af7987f2168242baa2476d3b6c955440", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 9728 0.641774 1\n", - "1 73446 7793 0.424696 2\n", - "2 73446 7829 0.371349 3\n", - "3 73446 14317 0.357083 4\n", - "4 73446 3182 0.068661 5\n", - "... ... ... ... ...\n", - "947045 857162 3734 1.211454 6\n", - "947046 857162 8636 1.105020 7\n", - "947047 857162 4151 1.053518 8\n", - "947048 857162 14703 0.883201 9\n", - "947049 857162 12995 0.827246 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'MAP@1': 0.04615447203061492,\n", - " 'MAP@5': 0.07738831584888614,\n", - " 'MAP@10': 0.08574348640312766,\n", - " 'MIUF@1': 3.904076196457114,\n", - " 'MIUF@5': 4.52025063365768,\n", - " 'MIUF@10': 5.012999210434636,\n", - " 'Serendipity@1': 0.0004970568200398246,\n", - " 'Serendipity@5': 0.000480108457862388,\n", - " 'Serendipity@10': 0.00045721034938317576,\n", - " 'model': 'bert4rec_ids'}]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rectools_origin", - "language": "python", - "name": "rectools_origin" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb deleted file mode 100644 index 1e7df0e4..00000000 --- a/examples/sasrec_metrics_comp.ipynb +++ /dev/null @@ -1,2647 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "import os\n", - "import threadpoolctl\n", - "import torch\n", - "from pathlib import Path\n", - "from lightning_fabric import seed_everything\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from rectools import Columns\n", - "\n", - "from implicit.als import AlternatingLeastSquares\n", - "\n", - "from rectools.dataset import Dataset\n", - "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", - "from rectools.models import ImplicitALSWrapperModel\n", - "from rectools.models import SASRecModel\n", - "from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", - "\n", - "# For implicit ALS\n", - "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", - "threadpoolctl.threadpool_limits(1, \"blas\")\n", - "\n", - "logging.basicConfig()\n", - "logging.getLogger().setLevel(logging.INFO)\n", - "logger = logging.getLogger()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# %%time\n", - "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O 
data_original.zip\n", - "# !unzip -o data_original.zip\n", - "# !rm data_original.zip" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_PATH = Path(\"data_original\")\n", - "\n", - "interactions = (\n", - " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", - " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Split dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", - "\n", - "# Split to train / test\n", - "max_date = interactions[Columns.Datetime].max()\n", - "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", - "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", - "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", - "\n", - "# drop items with less than 20 interactions in train\n", - "items = train[\"item_id\"].value_counts()\n", - "items = items[items >= 20]\n", - "items = items.index.to_list()\n", - "train = train[train[\"item_id\"].isin(items)]\n", - " \n", - "# drop users with less than 2 interactions in train\n", - "users = train[\"user_id\"].value_counts()\n", - "users = users[users >= 2]\n", - "users = users.index.to_list()\n", - "train = train[(train[\"user_id\"].isin(users))]\n", - "\n", - "users = train[\"user_id\"].drop_duplicates().to_list()\n", - "\n", - "# drop cold users from test\n", - "test_users_sasrec = test[Columns.User].unique()\n", - "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", - "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", - "test_users = test[Columns.User].unique()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "items = pd.read_csv(DATA_PATH / 'items.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# Process item features to the form of a flatten dataframe\n", - "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", - "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", - "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", - "genre_feature.columns = [\"id\", \"value\"]\n", - "genre_feature[\"feature\"] = \"genre\"\n", - "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", - "content_feature.columns = [\"id\", \"value\"]\n", - "content_feature[\"feature\"] = \"content_type\"\n", - "item_features = pd.concat((genre_feature, content_feature))\n", - "\n", - "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", - "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", - "\n", - "catalog=train[Columns.Item].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_no_features = Dataset.construct(\n", - " interactions_df=train,\n", - ")\n", - "\n", - "dataset_item_features = Dataset.construct(\n", - " interactions_df=train,\n", - " item_features_df=item_features,\n", - " cat_item_features=[\"genre\", \"content_type\"],\n", - ")" - ] - }, - { - "cell_type": "code", - 
"execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "metrics_name = {\n", - " 'MAP': MAP,\n", - " 'MIUF': MeanInvUserFreq,\n", - " 'Serendipity': Serendipity\n", - " \n", - "\n", - "}\n", - "metrics = {}\n", - "for metric_name, metric in metrics_name.items():\n", - " for k in (1, 5, 10):\n", - " metrics[f'{metric_name}@{k}'] = metric(k=k)\n", - "\n", - "# list with metrics results of all models\n", - "features_results = []" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SASRec" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Softmax loss" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - " recommend_device=\"cuda\",\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", - "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3040ffdf32ad4a8a9d7d55de6524f3ad", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e00badfab5564eb48c4bcde5aeae378a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 9728 2.073619 1\n", - "1 73446 7793 2.032206 2\n", - "2 73446 3784 1.862259 3\n", - "3 73446 12965 1.821924 4\n", - "4 73446 3182 1.760829 5\n", - "... ... ... ... ...\n", - "947045 857162 6809 2.055156 6\n", - "947046 857162 12995 1.811543 7\n", - "947047 857162 15362 1.524744 8\n", - "947048 857162 4495 1.470260 9\n", - "947049 857162 47 1.443814 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## BCE loss" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " loss=\"BCE\",\n", - " n_negatives=2,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", - "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0ef5450c14a8497983690b871b6f2a34", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9192a96923c24843b92e1c7e2ab184d7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 9728 3.905239 1\n", - "1 73446 7829 3.675018 2\n", - "2 73446 7793 3.606547 3\n", - "3 73446 3784 3.505795 4\n", - "4 73446 6646 3.314492 5\n", - "... ... ... ... ...\n", - "947045 857162 3734 3.120592 6\n", - "947046 857162 12995 2.934546 7\n", - "947047 857162 12692 2.731019 8\n", - "947048 857162 657 2.596890 9\n", - "947049 857162 15362 2.591072 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## gBCE loss" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " loss=\"gBCE\",\n", - " n_negatives=256,\n", - " gbce_t=0.75,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", - "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d493a5caa5f74735a8916101d422f022", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9d9a87779d924688b7498005aa3dadc9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bdf51affb5944bc1a4630b17aefbc68f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? 
[00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "" - ], - "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", - "model \n", - "softmax 0.048488 0.082141 0.091258 3.868677 4.506242 \n", - "softmax_padding_mask 0.046903 0.080840 0.089952 4.044813 4.601506 \n", - "gBCE 0.046173 0.080317 0.089186 3.186794 3.822623 \n", - "bce 0.043094 0.073984 0.082944 3.598352 4.304013 \n", - "\n", - " MIUF@10 Serendipity@1 Serendipity@5 Serendipity@10 \n", - "model \n", - "softmax 5.087812 0.001059 0.000828 0.000751 \n", - "softmax_padding_mask 5.091948 0.001064 0.000822 0.000720 \n", - "gBCE 4.538796 0.000640 0.000517 0.000505 \n", - "bce 4.916007 0.000477 0.000486 0.000502 " - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_df = (\n", - " pd.DataFrame(features_results)\n", - " .set_index(\"model\")\n", - " .sort_values(by=[\"MAP@10\", \"Serendipity@10\"], ascending=False)\n", - ")\n", - "features_df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### sasrec with item ids embeddings in ItemNetBlock" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", - "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9f638679084f411ba50da747f7635338", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "18af83b15a1041de92fab2cdb62379ab", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#%%time\n", - "model.fit(dataset_item_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b66c4dbd535f4601af3e8fb91b0de523", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? 
[00:00" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#%%time\n", - "model.fit(dataset_item_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bddb0ea860994c5e9eafb49b598756a0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
target_item_iditem_idscorerank
013865156481.0000001
11386533861.0000002
213865161940.9087693
3138651470.9087694
413865125860.9087695
513865123090.9087696
61386566610.9087697
71386522550.9087698
81386541300.9087699
91386591090.90876910
10445751091.0000001
11445788511.0000002
12445784861.0000003
134457120871.0000004
14445723131.0000005
154457119771.0000006
16445733841.0000007
17445762851.0000008
18445779281.0000009
194457115131.00000010
201529787231.0000001
211529759261.0000002
221529741311.0000003
231529742291.0000004
241529770051.0000005
2515297107971.0000006
2615297105351.0000007
271529754001.0000008
281529747161.0000009
2915297131031.00000010
\n", - "" - ], - "text/plain": [ - " target_item_id item_id score rank\n", - "0 13865 15648 1.000000 1\n", - "1 13865 3386 1.000000 2\n", - "2 13865 16194 0.908769 3\n", - "3 13865 147 0.908769 4\n", - "4 13865 12586 0.908769 5\n", - "5 13865 12309 0.908769 6\n", - "6 13865 6661 0.908769 7\n", - "7 13865 2255 0.908769 8\n", - "8 13865 4130 0.908769 9\n", - "9 13865 9109 0.908769 10\n", - "10 4457 5109 1.000000 1\n", - "11 4457 8851 1.000000 2\n", - "12 4457 8486 1.000000 3\n", - "13 4457 12087 1.000000 4\n", - "14 4457 2313 1.000000 5\n", - "15 4457 11977 1.000000 6\n", - "16 4457 3384 1.000000 7\n", - "17 4457 6285 1.000000 8\n", - "18 4457 7928 1.000000 9\n", - "19 4457 11513 1.000000 10\n", - "20 15297 8723 1.000000 1\n", - "21 15297 5926 1.000000 2\n", - "22 15297 4131 1.000000 3\n", - "23 15297 4229 1.000000 4\n", - "24 15297 7005 1.000000 5\n", - "25 15297 10797 1.000000 6\n", - "26 15297 10535 1.000000 7\n", - "27 15297 5400 1.000000 8\n", - "28 15297 4716 1.000000 9\n", - "29 15297 13103 1.000000 10" - ] - }, - "execution_count": 56, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
target_item_iditem_idscoreranktitle
013865156481.0000001Черное золото
11386533861.0000002Спартак
213865161940.9087693Голубая линия
3138651470.9087694Единичка
413865125860.9087695Вспоминая 1942
513865123090.9087696Враг у ворот
61386566610.9087697Солдатик
71386522550.9087698Пленный
81386541300.9087699Пустота
91386591090.90876910Последняя битва
10445751091.0000001Время разлуки
11445788511.0000002Лисы
12445784861.0000003Мой создатель
134457120871.0000004Молчаливое бегство
14445723131.0000005Свет моей жизни
154457119771.0000006Зоология
16445733841.0000007Вивариум
17445762851.0000008Божественная любовь
18445779281.0000009Вечная жизнь
194457115131.00000010Любовь
201529787231.0000001Секс в другом городе: Поколение Q
211529759261.0000002Пациенты
221529741311.0000003Учитель Ким, доктор Романтик
231529742291.0000004Самара
241529770051.0000005Чёрная кровь
2515297107971.0000006Наследники
2615297105351.0000007Я могу уничтожить тебя
271529754001.0000008Частица вселенной
281529747161.0000009Мастера секса
2915297131031.00000010Хороший доктор
\n", - "
" - ], - "text/plain": [ - " target_item_id item_id score rank title\n", - "0 13865 15648 1.000000 1 Черное золото\n", - "1 13865 3386 1.000000 2 Спартак\n", - "2 13865 16194 0.908769 3 Голубая линия\n", - "3 13865 147 0.908769 4 Единичка\n", - "4 13865 12586 0.908769 5 Вспоминая 1942\n", - "5 13865 12309 0.908769 6 Враг у ворот\n", - "6 13865 6661 0.908769 7 Солдатик\n", - "7 13865 2255 0.908769 8 Пленный\n", - "8 13865 4130 0.908769 9 Пустота\n", - "9 13865 9109 0.908769 10 Последняя битва\n", - "10 4457 5109 1.000000 1 Время разлуки\n", - "11 4457 8851 1.000000 2 Лисы\n", - "12 4457 8486 1.000000 3 Мой создатель\n", - "13 4457 12087 1.000000 4 Молчаливое бегство\n", - "14 4457 2313 1.000000 5 Свет моей жизни\n", - "15 4457 11977 1.000000 6 Зоология\n", - "16 4457 3384 1.000000 7 Вивариум\n", - "17 4457 6285 1.000000 8 Божественная любовь\n", - "18 4457 7928 1.000000 9 Вечная жизнь\n", - "19 4457 11513 1.000000 10 Любовь\n", - "20 15297 8723 1.000000 1 Секс в другом городе: Поколение Q\n", - "21 15297 5926 1.000000 2 Пациенты\n", - "22 15297 4131 1.000000 3 Учитель Ким, доктор Романтик\n", - "23 15297 4229 1.000000 4 Самара\n", - "24 15297 7005 1.000000 5 Чёрная кровь\n", - "25 15297 10797 1.000000 6 Наследники\n", - "26 15297 10535 1.000000 7 Я могу уничтожить тебя\n", - "27 15297 5400 1.000000 8 Частица вселенной\n", - "28 15297 4716 1.000000 9 Мастера секса\n", - "29 15297 13103 1.000000 10 Хороший доктор" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# TODO: change model for recos (here is the last one trained and is is the worst in quality)\n", - "recos.merge(items[[\"item_id\", \"title\"]], on=\"item_id\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rectools_origin", - "language": "python", - "name": "rectools_origin" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 5a20525c..d215f7c6 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -36,7 +36,9 @@ def from_dataset(cls, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> tp.O raise NotImplementedError() @classmethod - def from_dataset_schema(cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any) -> tpe.Self: + def from_dataset_schema( + cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any + ) -> tp.Optional[tpe.Self]: """Construct ItemNet from Dataset schema.""" raise NotImplementedError() @@ -66,19 +68,23 @@ class CatFeaturesItemNet(ItemNetBase): def __init__( self, - item_features: SparseFeatures, + emb_bag_inputs: torch.Tensor, + input_lengths: torch.Tensor, + offsets: torch.Tensor, + n_cat_feature_values: int, n_factors: int, dropout_rate: float, ): super().__init__() - self.item_features = item_features - self.n_items = len(item_features) - self.n_cat_features = len(item_features.names) - - self.category_embeddings = nn.Embedding(num_embeddings=self.n_cat_features, embedding_dim=n_factors) + self.n_cat_feature_values = n_cat_feature_values + self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, 
embedding_dim=n_factors, mode="sum") self.drop_layer = nn.Dropout(dropout_rate) + self.register_buffer("offsets", offsets) + self.register_buffer("emb_bag_inputs", emb_bag_inputs) + self.register_buffer("input_lengths", input_lengths) + def forward(self, items: torch.Tensor) -> torch.Tensor: """ Forward pass to get item embeddings from categorical item features. @@ -93,35 +99,21 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: torch.Tensor Item embeddings. """ - feature_dense = self.get_dense_item_features(items) - - feature_embs = self.category_embeddings(self.feature_catalog.to(self.device)) - feature_embs = self.drop_layer(feature_embs) - - feature_embeddings_per_items = feature_dense.to(self.device) @ feature_embs + item_emb_bag_inputs, item_offsets = self.get_item_inputs_offsets(items) + feature_embeddings_per_items = self.embedding_bag(input=item_emb_bag_inputs, offsets=item_offsets) + feature_embeddings_per_items = self.drop_layer(feature_embeddings_per_items) return feature_embeddings_per_items - @property - def feature_catalog(self) -> torch.Tensor: - """Return tensor with elements in range [0, n_cat_features).""" - return torch.arange(0, self.n_cat_features) - - def get_dense_item_features(self, items: torch.Tensor) -> torch.Tensor: - """ - Get categorical item values by certain item ids in dense format. - - Parameters - ---------- - items : torch.Tensor - Internal item ids. - - Returns - ------- - torch.Tensor - categorical item values in dense format. - """ - feature_dense = self.item_features.take(items.detach().cpu().numpy()).get_dense() - return torch.from_numpy(feature_dense) + def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get categorical item features and offsets for `items`.""" + length_range = torch.arange(self.input_lengths.max().item(), device=self.device) + item_indexes = self.offsets[items].unsqueeze(-1) + length_range + length_mask = length_range < self.input_lengths[items].unsqueeze(-1) + item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]] + item_offsets = torch.cat( + (torch.tensor([0], device=self.device), torch.cumsum(self.input_lengths[items], dim=0)[:-1]) + ) + return item_emb_bag_inputs, item_offsets @classmethod def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> tp.Optional[tpe.Self]: @@ -161,7 +153,59 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> warnings.warn(explanation) return None - return cls(item_cat_features, n_factors, dropout_rate) + emb_bag_inputs = torch.tensor(item_cat_features.values.indices, dtype=torch.long) + offsets = torch.tensor(item_cat_features.values.indptr, dtype=torch.long) + input_lengths = torch.diff(offsets, dim=0) + n_cat_feature_values = len(item_cat_features.names) + + return cls( + emb_bag_inputs=emb_bag_inputs, + offsets=offsets[:-1], + input_lengths=input_lengths, + n_cat_feature_values=n_cat_feature_values, + n_factors=n_factors, + dropout_rate=dropout_rate, + ) + + @classmethod + def from_dataset_schema( + cls, dataset_schema: DatasetSchema, n_factors: int, dropout_rate: float + ) -> tp.Optional[tpe.Self]: + """Construct CatFeaturesItemNet from Dataset schema.""" + if dataset_schema.items.features is None: + explanation = """Ignoring `CatFeaturesItemNet` block because dataset doesn't contain item features.""" + warnings.warn(explanation) + return None + + if dataset_schema.items.features.kind == "dense": + explanation = """ + Ignoring `CatFeaturesItemNet` block because + 
dataset item features are dense and cannot contain categorical features. + """ + warnings.warn(explanation) + return None + + if len(dataset_schema.items.features.cat_feature_indices) == 0: + explanation = """ + Ignoring `CatFeaturesItemNet` block because dataset item features do not contain categorical features. + """ + warnings.warn(explanation) + return None + + emb_bag_inputs = torch.randint( + high=dataset_schema.items.n_hot, size=(dataset_schema.items.features.cat_n_stored_values,) + ) + offsets = torch.randint(high=dataset_schema.items.n_hot, size=(dataset_schema.items.n_hot,)) + input_lengths = torch.randint(high=dataset_schema.items.n_hot, size=(dataset_schema.items.n_hot,)) + n_cat_feature_values = len(dataset_schema.items.features.cat_feature_indices) + return cls( + emb_bag_inputs=emb_bag_inputs, + offsets=offsets, + input_lengths=input_lengths, + n_cat_feature_values=n_cat_feature_values, + n_factors=n_factors, + dropout_rate=dropout_rate, + ) class IdEmbeddingsItemNet(ItemNetBase): diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index bc861e71..69f641ea 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -22,7 +22,6 @@ from rectools.columns import Columns from rectools.dataset import Dataset -from rectools.dataset.features import SparseFeatures from rectools.models.nn.item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, @@ -30,7 +29,6 @@ ItemNetConstructorBase, SumOfEmbeddingsConstructor, ) -from tests.testing_utils import assert_feature_set_equal from ..data import DATASET, INTERACTIONS @@ -75,8 +73,6 @@ def _seed_everything(self) -> None: def dataset_item_features(self) -> Dataset: item_features = pd.DataFrame( [ - [11, "f1", "f1val1"], - [11, "f2", "f2val1"], [12, "f1", "f1val1"], [12, "f2", "f2val2"], [13, "f1", "f1val1"], @@ -89,7 +85,6 @@ def dataset_item_features(self) -> Dataset: [17, "f2", "f2val3"], [16, "f1", "f1val2"], [16, "f2", "f2val3"], - [11, "f3", 0], [12, "f3", 1], [13, "f3", 2], [14, "f3", 3], @@ -106,34 +101,19 @@ def dataset_item_features(self) -> Dataset: ) return ds - def test_feature_catalog(self, dataset_item_features: Dataset) -> None: - cat_item_embeddings = CatFeaturesItemNet.from_dataset(dataset_item_features, n_factors=5, dropout_rate=0.5) - assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - expected_feature_catalog = torch.arange(0, cat_item_embeddings.n_cat_features) - assert torch.equal(cat_item_embeddings.feature_catalog, expected_feature_catalog) - - def test_get_dense_item_features(self, dataset_item_features: Dataset) -> None: + def test_get_item_inputs_offsets(self, dataset_item_features: Dataset) -> None: items = torch.from_numpy( dataset_item_features.item_id_map.convert_to_internal(INTERACTIONS[Columns.Item].unique()) - ) + )[:-1] cat_item_embeddings = CatFeaturesItemNet.from_dataset(dataset_item_features, n_factors=5, dropout_rate=0.5) assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - actual_feature_dense = cat_item_embeddings.get_dense_item_features(items) - expected_feature_dense = torch.tensor( - [ - [1, 0, 1, 0, 0], - [1, 0, 0, 1, 0], - [0, 1, 1, 0, 0], - [1, 0, 0, 0, 1], - [0, 1, 0, 1, 0], - [0, 1, 0, 0, 1], - ], - dtype=torch.float, - ) - - assert torch.equal(actual_feature_dense, expected_feature_dense) + actual_item_emb_bag_inputs, actual_item_offsets = cat_item_embeddings.get_item_inputs_offsets(items) + expected_item_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2]) + expected_item_offsets = torch.tensor([0, 0, 2, 4, 6])
+ assert torch.equal(actual_item_emb_bag_inputs, expected_item_emb_bag_inputs) + assert torch.equal(actual_item_offsets, expected_item_offsets) @pytest.mark.parametrize("n_factors", (10, 100)) def test_create_from_dataset(self, n_factors: int, dataset_item_features: Dataset) -> None: @@ -143,20 +123,21 @@ def test_create_from_dataset(self, n_factors: int, dataset_item_features: Datase assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - actual_item_features = cat_item_embeddings.item_features - actual_n_items = cat_item_embeddings.n_items - actual_n_cat_features = cat_item_embeddings.n_cat_features - actual_embedding_dim = cat_item_embeddings.category_embeddings.embedding_dim - - expected_item_features = dataset_item_features.item_features + actual_offsets = cat_item_embeddings.offsets + actual_n_cat_feature_values = cat_item_embeddings.n_cat_feature_values + actual_embedding_dim = cat_item_embeddings.embedding_bag.embedding_dim + actual_emb_bag_inputs = cat_item_embeddings.emb_bag_inputs + actual_input_lengths = cat_item_embeddings.input_lengths - assert isinstance(expected_item_features, SparseFeatures) - expected_cat_item_features = expected_item_features.get_cat_features() + expected_offsets = torch.tensor([0, 0, 2, 4, 6, 8, 10]) + expected_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2, 1, 3, 1, 3]) + expected_input_lengths = torch.tensor([0, 2, 2, 2, 2, 2, 2]) - assert_feature_set_equal(actual_item_features, expected_cat_item_features) - assert actual_n_items == dataset_item_features.item_id_map.size - assert actual_n_cat_features == len(expected_cat_item_features.names) + assert actual_n_cat_feature_values == 5 assert actual_embedding_dim == n_factors + assert torch.equal(actual_offsets, expected_offsets) + assert torch.equal(actual_emb_bag_inputs, expected_emb_bag_inputs) + assert torch.equal(actual_input_lengths, expected_input_lengths) @pytest.mark.parametrize( "n_items,n_factors", diff --git a/tests/models/nn/test_transformer_base.py b/tests/models/nn/test_transformer_base.py index df6f2c25..1344e7a7 100644 --- a/tests/models/nn/test_transformer_base.py +++ b/tests/models/nn/test_transformer_base.py @@ -26,10 +26,11 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import BERT4RecModel, SASRecModel, load_model -from rectools.models.nn.item_net import IdEmbeddingsItemNet +from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformer_base import TransformerModelBase from tests.models.utils import assert_save_load_do_not_change_model +from ..data import INTERACTIONS from .utils import custom_trainer, leave_one_out_mask @@ -68,6 +69,38 @@ def interactions_df(self) -> pd.DataFrame: def dataset(self, interactions_df: pd.DataFrame) -> Dataset: return Dataset.construct(interactions_df) + @pytest.fixture + def dataset_item_features(self) -> Dataset: + item_features = pd.DataFrame( + [ + [12, "f1", "f1val1"], + [12, "f2", "f2val2"], + [13, "f1", "f1val1"], + [13, "f2", "f2val3"], + [14, "f1", "f1val2"], + [14, "f2", "f2val1"], + [15, "f1", "f1val2"], + [15, "f2", "f2val2"], + [17, "f1", "f1val2"], + [17, "f2", "f2val3"], + [16, "f1", "f1val2"], + [16, "f2", "f2val3"], + [12, "f3", 1], + [13, "f3", 2], + [14, "f3", 3], + [15, "f3", 4], + [17, "f3", 5], + [16, "f3", 6], + ], + columns=["id", "feature", "value"], + ) + ds = Dataset.construct( + INTERACTIONS, + item_features_df=item_features, + cat_item_features=["f1", "f2"], + ) + return ds + @pytest.mark.parametrize("model_cls", 
(SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("default_trainer", (True, False)) def test_save_load_for_unfitted_model( @@ -75,7 +108,7 @@ def test_save_load_for_unfitted_model( ) -> None: config = { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -106,29 +139,33 @@ def _assert_same_reco(self, model_1: TransformerModelBase, model_2: TransformerM @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("default_trainer", (True, False)) def test_save_load_for_fitted_model( - self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset, default_trainer: bool, trainer: Trainer + self, + model_cls: tp.Type[TransformerModelBase], + dataset_item_features: Dataset, + default_trainer: bool, + trainer: Trainer, ) -> None: config = { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } if not default_trainer: config["get_trainer_func"] = custom_trainer model = model_cls.from_config(config) - model.fit(dataset) - assert_save_load_do_not_change_model(model, dataset) + model.fit(dataset_item_features) + assert_save_load_do_not_change_model(model, dataset_item_features) @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) def test_load_from_checkpoint( self, model_cls: tp.Type[TransformerModelBase], tmp_path: str, - dataset: Dataset, + dataset_item_features: Dataset, ) -> None: model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } ) model._trainer = Trainer( # pylint: disable=protected-access @@ -140,7 +177,7 @@ def test_load_from_checkpoint( devices=1, callbacks=ModelCheckpoint(filename="last_epoch"), ) - model.fit(dataset) + model.fit(dataset_item_features) assert model.fit_trainer is not None if model.fit_trainer.log_dir is None: @@ -150,7 +187,7 @@ def test_load_from_checkpoint( recovered_model = model_cls.load_from_checkpoint(ckpt_path) assert isinstance(recovered_model, model_cls) - self._assert_same_reco(model, recovered_model, dataset) + self._assert_same_reco(model, recovered_model, dataset_item_features) @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("verbose", (1, 0))
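
The central refactor in `item_net.py` swaps the dense one-hot matmul for `nn.EmbeddingBag` driven by CSR-style buffers. The following minimal sketch uses toy tensors (only the variable names mirror the diff) to reproduce the gather that `get_item_inputs_offsets` performs and to show how its outputs feed the bag:

```python
import torch

# Toy CSR-style buffers for 3 items over 5 categorical feature values (assumed data).
emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3])  # flat feature ids, item by item
offsets = torch.tensor([0, 2, 4])                  # start of each item's slice
input_lengths = torch.tensor([2, 2, 2])            # number of features per item

items = torch.tensor([2, 0])  # request items 2 and 0, in this order

# Same logic as get_item_inputs_offsets: build per-item index ranges, mask out
# positions beyond each item's feature count, then gather the flat feature ids.
length_range = torch.arange(input_lengths.max().item())
item_indexes = offsets[items].unsqueeze(-1) + length_range
length_mask = length_range < input_lengths[items].unsqueeze(-1)
item_emb_bag_inputs = emb_bag_inputs[item_indexes[length_mask]]  # tensor([0, 3, 0, 2])
item_offsets = torch.cat(
    (torch.tensor([0]), torch.cumsum(input_lengths[items], dim=0)[:-1])
)  # tensor([0, 2])

# mode="sum" aggregates each item's feature embeddings, matching the effect of
# the removed one-hot @ embedding-table product.
embedding_bag = torch.nn.EmbeddingBag(num_embeddings=5, embedding_dim=4, mode="sum")
item_embeddings = embedding_bag(input=item_emb_bag_inputs, offsets=item_offsets)
assert item_embeddings.shape == (2, 4)
```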
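The buffer construction in `from_dataset` maps directly from a `scipy.sparse` CSR matrix of categorical item features. A hedged sketch under assumed data, with a toy matrix standing in for `item_cat_features.values`:

```python
import torch
from scipy import sparse

# Toy one-hot matrix: 3 items x 5 categorical feature values (assumed data).
values = sparse.csr_matrix(
    [
        [1, 0, 1, 0, 0],
        [0, 1, 0, 0, 1],
        [1, 0, 0, 1, 0],
    ]
)
emb_bag_inputs = torch.tensor(values.indices, dtype=torch.long)  # flat column ids
offsets = torch.tensor(values.indptr, dtype=torch.long)          # length n_items + 1
input_lengths = torch.diff(offsets, dim=0)                       # features per item
offsets = offsets[:-1]                                           # per-item starts, as passed to cls(...)

assert emb_bag_inputs.tolist() == [0, 2, 1, 4, 0, 3]
assert offsets.tolist() == [0, 2, 4]
assert input_lengths.tolist() == [2, 2, 2]
```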
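With `from_dataset_schema` implemented, `CatFeaturesItemNet` survives model serialization, which is why the tests above now enable it alongside `IdEmbeddingsItemNet` instead of the previous TODO. A usage sketch mirroring the updated test configs (fitting and dataset setup as in the fixtures):

```python
from rectools.models import BERT4RecModel
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet

# Item embeddings become the sum of id embeddings and categorical-feature
# embeddings; CatFeaturesItemNet is skipped with a warning when the dataset
# has no categorical item features.
model = BERT4RecModel.from_config(
    {
        "deterministic": True,
        "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet),
    }
)
```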