From ad3d41c356b5753f6c005bb12be931167bcc2003 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Sun, 2 Feb 2025 18:49:49 +0300 Subject: [PATCH 1/5] added embedding_bag --- examples/sasrec_metrics_comp.ipynb | 1474 ++++++++++++++++------------ rectools/models/nn/item_net.py | 67 +- 2 files changed, 905 insertions(+), 636 deletions(-) diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb index 1e7df0e4..a144af86 100644 --- a/examples/sasrec_metrics_comp.ipynb +++ b/examples/sasrec_metrics_comp.ipynb @@ -5,6 +5,16 @@ "execution_count": 2, "metadata": {}, "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], "source": [ "import logging\n", "import os\n", @@ -28,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -120,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -129,20 +139,28 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "# Process item features to the form of a flatten dataframe\n", "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", + "\n", "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", "genre_feature.columns = [\"id\", \"value\"]\n", "genre_feature[\"feature\"] = \"genre\"\n", + "\n", + "items[\"director\"] = items[\"directors\"].str.lower().str.replace(\" \", \"\", regex=False).replace(\", \", \",\", regex=False).str.split(\",\")\n", + "directors_feature = items[[\"item_id\", \"director\"]].explode(\"director\")\n", + "directors_feature.columns = [\"id\", \"value\"]\n", + "directors_feature[\"feature\"] = \"director\"\n", + "\n", "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", "content_feature.columns = [\"id\", \"value\"]\n", "content_feature[\"feature\"] = \"content_type\"\n", - "item_features = pd.concat((genre_feature, content_feature))\n", + "item_features_genre_content = pd.concat((genre_feature, content_feature))\n", + "item_features_genre_director = pd.concat((genre_feature, directors_feature))\n", "\n", "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", @@ -163,8 +181,14 @@ "\n", "dataset_item_features = Dataset.construct(\n", " interactions_df=train,\n", - " item_features_df=item_features,\n", + " item_features_df=item_features_genre_content,\n", " cat_item_features=[\"genre\", \"content_type\"],\n", + ")\n", + "\n", + "dataset_item_features_genre_director = Dataset.construct(\n", + " interactions_df=train,\n", + " item_features_df=item_features_genre_director,\n", + " cat_item_features=[\"genre\", \"director\"],\n", ")" ] }, @@ -237,7 +261,18 @@ "cell_type": "code", "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -247,7 +282,6 @@ " verbose=1,\n", " deterministic=True,\n", " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - " recommend_device=\"cuda\",\n", ")\n" ] }, @@ -260,24 +294,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", + "-----------------------------------------------------------------------\n", + "2.2 M Trainable params\n", "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.2 M Total params\n", + "8.991 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3040ffdf32ad4a8a9d7d55de6524f3ad", + "model_id": "2503c49e4f98450c91d2b1cfd0abbe8c", "version_major": 2, "version_minor": 0 }, @@ -299,14 +337,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 5min 9s, sys: 6.07 s, total: 5min 15s\n", - "Wall time: 5min\n" + "CPU times: user 6min 12s, sys: 7.72 s, total: 6min 20s\n", + "Wall time: 6min 9s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 14, @@ -328,25 +366,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e00badfab5564eb48c4bcde5aeae378a", + "model_id": "cf1fbcbcf8af474db553209f7e279c63", "version_major": 2, "version_minor": 0 }, @@ -361,8 +398,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 26.4 s, sys: 3.84 s, total: 30.2 s\n", - "Wall time: 21 s\n" + "CPU times: user 2min 27s, sys: 5.38 s, total: 2min 32s\n", + "Wall time: 19.2 s\n" ] } ], @@ -396,15 +433,15 @@ { "data": { "text/plain": [ - "[{'MAP@1': 0.048488013974905805,\n", - " 'MAP@5': 0.0821411524595932,\n", - " 'MAP@10': 0.0912577746921091,\n", - " 'MIUF@1': 3.868676865340589,\n", - " 'MIUF@5': 4.506241791317061,\n", - " 'MIUF@10': 5.087812416018942,\n", - " 'Serendipity@1': 0.0010587683094952943,\n", - " 'Serendipity@5': 0.0008279085147448243,\n", - " 'Serendipity@10': 0.0007506236395264775,\n", + "[{'MAP@1': 0.04846577699474078,\n", + " 'MAP@5': 0.0816953145406517,\n", + " 'MAP@10': 0.09070442769366964,\n", + " 'MIUF@1': 3.871426206344739,\n", + " 'MIUF@5': 4.573068555853547,\n", + " 'MIUF@10': 5.159742458558834,\n", + " 'Serendipity@1': 0.001116687417059873,\n", + " 'Serendipity@5': 0.0008645696959881002,\n", + " 'Serendipity@10': 0.0007632648657992071,\n", " 'model': 'softmax'}]" ] }, @@ -454,35 +491,35 @@ " 0\n", " 73446\n", " 9728\n", - " 2.073619\n", + " 2.401881\n", " 1\n", " \n", " \n", " 1\n", " 73446\n", " 7793\n", - " 2.032206\n", + " 1.923069\n", " 2\n", " \n", " \n", " 2\n", " 73446\n", " 3784\n", - " 1.862259\n", + " 1.824613\n", " 3\n", " \n", " \n", " 3\n", " 73446\n", - " 12965\n", - " 1.821924\n", + " 3182\n", + " 1.666528\n", " 4\n", " \n", " \n", " 4\n", " 73446\n", - " 3182\n", - " 1.760829\n", + " 7829\n", + " 1.662176\n", " 5\n", " \n", " \n", @@ -495,36 +532,36 @@ " \n", " 947045\n", " 857162\n", - " 6809\n", - " 2.055156\n", + " 12995\n", + " 2.385432\n", " 6\n", " \n", " \n", " 947046\n", " 857162\n", - " 12995\n", - " 1.811543\n", + " 6809\n", + " 2.360935\n", " 7\n", " \n", " \n", " 947047\n", " 857162\n", - " 15362\n", - " 1.524744\n", + " 657\n", + " 1.940931\n", " 8\n", " \n", " \n", " 947048\n", " 857162\n", - " 4495\n", - " 1.470260\n", + " 4702\n", + " 1.866479\n", " 9\n", " \n", " \n", " 947049\n", " 857162\n", - " 47\n", - " 1.443814\n", + " 16447\n", + " 1.758027\n", " 10\n", " \n", " \n", @@ -534,17 +571,17 @@ ], "text/plain": [ " user_id item_id score rank\n", - "0 73446 9728 2.073619 1\n", - "1 73446 7793 2.032206 2\n", - "2 73446 3784 1.862259 3\n", - "3 73446 12965 1.821924 4\n", - "4 73446 3182 1.760829 5\n", + "0 73446 9728 2.401881 1\n", + "1 73446 7793 1.923069 2\n", + "2 73446 3784 1.824613 3\n", + "3 73446 3182 1.666528 4\n", + "4 73446 7829 1.662176 5\n", "... ... ... ... ...\n", - "947045 857162 6809 2.055156 6\n", - "947046 857162 12995 1.811543 7\n", - "947047 857162 15362 1.524744 8\n", - "947048 857162 4495 1.470260 9\n", - "947049 857162 47 1.443814 10\n", + "947045 857162 12995 2.385432 6\n", + "947046 857162 6809 2.360935 7\n", + "947047 857162 657 1.940931 8\n", + "947048 857162 4702 1.866479 9\n", + "947049 857162 16447 1.758027 10\n", "\n", "[947050 rows x 4 columns]" ] @@ -598,7 +635,17 @@ "cell_type": "code", "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -622,24 +669,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", + "-----------------------------------------------------------------------\n", + "2.2 M Trainable params\n", "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.2 M Total params\n", + "8.991 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0ef5450c14a8497983690b871b6f2a34", + "model_id": "5dc9b5b0a50c416d97913c9f4b1d8c70", "version_major": 2, "version_minor": 0 }, @@ -661,14 +712,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 5min 14s, sys: 7.46 s, total: 5min 21s\n", - "Wall time: 5min 10s\n" + "CPU times: user 5min 41s, sys: 7.63 s, total: 5min 48s\n", + "Wall time: 5min 33s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 22, @@ -690,25 +741,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9192a96923c24843b92e1c7e2ab184d7", + "model_id": "ca2c4093c5464072a1affed53e39757c", "version_major": 2, "version_minor": 0 }, @@ -723,8 +773,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 24.9 s, sys: 3.25 s, total: 28.1 s\n", - "Wall time: 20.1 s\n" + "CPU times: user 2min 53s, sys: 6.47 s, total: 2min 59s\n", + "Wall time: 21.1 s\n" ] } ], @@ -786,36 +836,36 @@ " \n", " 0\n", " 73446\n", - " 9728\n", - " 3.905239\n", + " 3182\n", + " 3.370286\n", " 1\n", " \n", " \n", " 1\n", " 73446\n", - " 7829\n", - " 3.675018\n", + " 12965\n", + " 3.088001\n", " 2\n", " \n", " \n", " 2\n", " 73446\n", - " 7793\n", - " 3.606547\n", + " 6774\n", + " 3.056905\n", " 3\n", " \n", " \n", " 3\n", " 73446\n", - " 3784\n", - " 3.505795\n", + " 16270\n", + " 2.966968\n", " 4\n", " \n", " \n", " 4\n", " 73446\n", - " 6646\n", - " 3.314492\n", + " 7582\n", + " 2.965708\n", " 5\n", " \n", " \n", @@ -828,36 +878,36 @@ " \n", " 947045\n", " 857162\n", - " 3734\n", - " 3.120592\n", + " 4151\n", + " 2.733006\n", " 6\n", " \n", " \n", " 947046\n", " 857162\n", - " 12995\n", - " 2.934546\n", + " 142\n", + " 2.687315\n", " 7\n", " \n", " \n", " 947047\n", " 857162\n", - " 12692\n", - " 2.731019\n", + " 9728\n", + " 2.634741\n", " 8\n", " \n", " \n", " 947048\n", " 857162\n", - " 657\n", - " 2.596890\n", + " 3734\n", + " 2.558933\n", " 9\n", " \n", " \n", " 947049\n", " 857162\n", - " 15362\n", - " 2.591072\n", + " 9996\n", + " 2.479849\n", " 10\n", " \n", " \n", @@ -867,17 +917,17 @@ ], "text/plain": [ " user_id item_id score rank\n", - "0 73446 9728 3.905239 1\n", - "1 73446 7829 3.675018 2\n", - "2 73446 7793 3.606547 3\n", - "3 73446 3784 3.505795 4\n", - "4 73446 6646 3.314492 5\n", + "0 73446 3182 3.370286 1\n", + "1 73446 12965 3.088001 2\n", + "2 73446 6774 3.056905 3\n", + "3 73446 16270 2.966968 4\n", + "4 73446 7582 2.965708 5\n", "... ... ... ... ...\n", - "947045 857162 3734 3.120592 6\n", - "947046 857162 12995 2.934546 7\n", - "947047 857162 12692 2.731019 8\n", - "947048 857162 657 2.596890 9\n", - "947049 857162 15362 2.591072 10\n", + "947045 857162 4151 2.733006 6\n", + "947046 857162 142 2.687315 7\n", + "947047 857162 9728 2.634741 8\n", + "947048 857162 3734 2.558933 9\n", + "947049 857162 9996 2.479849 10\n", "\n", "[947050 rows x 4 columns]" ] @@ -931,7 +981,17 @@ "cell_type": "code", "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -956,24 +1016,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", + "-----------------------------------------------------------------------\n", + "2.2 M Trainable params\n", "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.2 M Total params\n", + "8.991 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d493a5caa5f74735a8916101d422f022", + "model_id": "79f807c7c72945b7b63dbc911107b461", "version_major": 2, "version_minor": 0 }, @@ -995,14 +1059,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2h 33min 27s, sys: 40.9 s, total: 2h 34min 8s\n", - "Wall time: 10min 45s\n" + "CPU times: user 1h 57min 29s, sys: 32.5 s, total: 1h 58min 1s\n", + "Wall time: 10min 24s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 29, @@ -1024,25 +1088,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9d9a87779d924688b7498005aa3dadc9", + "model_id": "145a981ae32e4efc83ff9759a207c331", "version_major": 2, "version_minor": 0 }, @@ -1057,8 +1120,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 25.2 s, sys: 3.15 s, total: 28.3 s\n", - "Wall time: 19.9 s\n" + "CPU times: user 2min 34s, sys: 4.91 s, total: 2min 39s\n", + "Wall time: 18.6 s\n" ] } ], @@ -1093,9 +1156,19 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 34, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -1105,38 +1178,41 @@ " verbose=1,\n", " deterministic=True,\n", " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - " recommend_device=\"cuda\",\n", " use_key_padding_mask=True,\n", ")\n" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", + "-----------------------------------------------------------------------\n", + "2.2 M Trainable params\n", "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.2 M Total params\n", + "8.991 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e0205350fdad45c59e0a5a8a722ddd14", + "model_id": "7259973b494c4e7b9b6700affd991d2c", "version_major": 2, "version_minor": 0 }, @@ -1148,29 +1224,50 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=5` reached.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 5min 17s, sys: 7.15 s, total: 5min 24s\n", - "Wall time: 5min 10s\n" + "ename": "RuntimeError", + "evalue": "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m:1\u001b[0m\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:306\u001b[0m, in \u001b[0;36mModelBase.fit\u001b[0;34m(self, dataset, *args, **kwargs)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, dataset: Dataset, \u001b[38;5;241m*\u001b[39margs: tp\u001b[38;5;241m.\u001b[39mAny, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: tp\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 294\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;124;03m Fit model.\u001b[39;00m\n\u001b[1;32m 296\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;124;03m self\u001b[39;00m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 306\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_fitted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:737\u001b[0m, in \u001b[0;36mTransformerModelBase._fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_lightning_model(torch_model)\n\u001b[1;32m 736\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfit_trainer \u001b[38;5;241m=\u001b[39m deepcopy(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer)\n\u001b[0;32m--> 737\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloader\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/optimizer.py:487\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 484\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 485\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 490\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/optimizer.py:91\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 90\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 91\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 93\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/adam.py:202\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 202\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 205\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:315\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule.training_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 313\u001b[0m x, y, w \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m], batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m], batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myw\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msoftmax\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 315\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_full_catalog_logits\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 316\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_calc_softmax_loss(logits, y, w)\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBCE\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:367\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule._get_full_catalog_logits\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_full_catalog_logits\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[0;32m--> 367\u001b[0m item_embs, session_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtorch_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 368\u001b[0m logits \u001b[38;5;241m=\u001b[39m session_embs \u001b[38;5;241m@\u001b[39m item_embs\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m logits\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:230\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder.forward\u001b[0;34m(self, sessions)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124;03mForward pass to get item and session embeddings.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;124;03mGet item embeddings.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;124;03m(torch.Tensor, torch.Tensor)\u001b[39;00m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 229\u001b[0m item_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitem_model\u001b[38;5;241m.\u001b[39mget_all_embeddings() \u001b[38;5;66;03m# [n_items + n_item_extra_tokens, n_factors]\u001b[39;00m\n\u001b[0;32m--> 230\u001b[0m session_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_sessions\u001b[49m\u001b[43m(\u001b[49m\u001b[43msessions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mitem_embs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch_size, session_max_len, n_factors]\u001b[39;00m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m item_embs, session_embs\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:205\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder.encode_sessions\u001b[0;34m(self, sessions, item_embs)\u001b[0m\n\u001b[1;32m 203\u001b[0m key_padding_mask \u001b[38;5;241m=\u001b[39m sessions \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attn_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# merge masks to prevent nan gradients for torch < 2.5.0\u001b[39;00m\n\u001b[0;32m--> 205\u001b[0m attn_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_merge_masks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseqs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m key_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 208\u001b[0m seqs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransformer_layers(seqs, timeline_mask, attn_mask, key_padding_mask)\n", + "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:163\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder._merge_masks\u001b[0;34m(self, attn_mask, key_padding_mask, query)\u001b[0m\n\u001b[1;32m 155\u001b[0m attn_mask_expanded \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_convert_mask_to_float(attn_mask, query) \u001b[38;5;66;03m# [session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m1\u001b[39m, seq_len, seq_len)\n\u001b[1;32m 158\u001b[0m \u001b[38;5;241m.\u001b[39mexpand(batch_size, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 159\u001b[0m ) \u001b[38;5;66;03m# [batch_size, session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 161\u001b[0m merged_mask \u001b[38;5;241m=\u001b[39m attn_mask_expanded \u001b[38;5;241m+\u001b[39m key_padding_mask_expanded\n\u001b[1;32m 162\u001b[0m res \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 163\u001b[0m \u001b[43mmerged_mask\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_heads\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m ) \u001b[38;5;66;03m# [batch_size * n_heads, session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 167\u001b[0m torch\u001b[38;5;241m.\u001b[39mdiagonal(res, dim1\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, dim2\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mzero_()\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\n", + "\u001b[0;31mRuntimeError\u001b[0m: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead." ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -1180,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1238,7 +1335,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1247,7 +1344,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1257,7 +1354,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -1307,73 +1404,59 @@ " \n", " \n", " softmax\n", - " 0.048488\n", - " 0.082141\n", - " 0.091258\n", - " 3.868677\n", - " 4.506242\n", - " 5.087812\n", - " 0.001059\n", - " 0.000828\n", - " 0.000751\n", - " \n", - " \n", - " softmax_padding_mask\n", - " 0.046903\n", - " 0.080840\n", - " 0.089952\n", - " 4.044813\n", - " 4.601506\n", - " 5.091948\n", - " 0.001064\n", - " 0.000822\n", - " 0.000720\n", + " 0.048466\n", + " 0.081695\n", + " 0.090704\n", + " 3.871426\n", + " 4.573069\n", + " 5.159742\n", + " 0.001117\n", + " 0.000865\n", + " 0.000763\n", " \n", " \n", " gBCE\n", - " 0.046173\n", - " 0.080317\n", - " 0.089186\n", - " 3.186794\n", - " 3.822623\n", - " 4.538796\n", - " 0.000640\n", - " 0.000517\n", - " 0.000505\n", + " 0.040848\n", + " 0.072356\n", + " 0.080166\n", + " 2.332397\n", + " 3.093763\n", + " 3.942205\n", + " 0.000103\n", + " 0.000118\n", + " 0.000134\n", " \n", " \n", " bce\n", - " 0.043094\n", - " 0.073984\n", - " 0.082944\n", - " 3.598352\n", - " 4.304013\n", - " 4.916007\n", - " 0.000477\n", - " 0.000486\n", - " 0.000502\n", + " 0.027035\n", + " 0.051244\n", + " 0.059080\n", + " 3.882081\n", + " 4.384314\n", + " 4.734298\n", + " 0.000104\n", + " 0.000121\n", + " 0.000131\n", " \n", " \n", "\n", "" ], "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", - "model \n", - "softmax 0.048488 0.082141 0.091258 3.868677 4.506242 \n", - "softmax_padding_mask 0.046903 0.080840 0.089952 4.044813 4.601506 \n", - "gBCE 0.046173 0.080317 0.089186 3.186794 3.822623 \n", - "bce 0.043094 0.073984 0.082944 3.598352 4.304013 \n", + " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 MIUF@10 \\\n", + "model \n", + "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 5.159742 \n", + "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 3.942205 \n", + "bce 0.027035 0.051244 0.059080 3.882081 4.384314 4.734298 \n", "\n", - " MIUF@10 Serendipity@1 Serendipity@5 Serendipity@10 \n", - "model \n", - "softmax 5.087812 0.001059 0.000828 0.000751 \n", - "softmax_padding_mask 5.091948 0.001064 0.000822 0.000720 \n", - "gBCE 4.538796 0.000640 0.000517 0.000505 \n", - "bce 4.916007 0.000477 0.000486 0.000502 " + " Serendipity@1 Serendipity@5 Serendipity@10 \n", + "model \n", + "softmax 0.001117 0.000865 0.000763 \n", + "gBCE 0.000103 0.000118 0.000134 \n", + "bce 0.000104 0.000121 0.000131 " ] }, - "execution_count": 38, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1396,9 +1479,19 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 38, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -1413,31 +1506,35 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 927 K \n", - "---------------------------------------------------------------\n", - "927 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", + "-----------------------------------------------------------------------\n", + "2.2 M Trainable params\n", "0 Non-trainable params\n", - "927 K Total params\n", - "3.709 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.2 M Total params\n", + "8.991 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9f638679084f411ba50da747f7635338", + "model_id": "14f1a094052c406ab577da3a5fef8f69", "version_major": 2, "version_minor": 0 }, @@ -1459,17 +1556,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 5min 21s, sys: 8.59 s, total: 5min 29s\n", - "Wall time: 5min 14s\n" + "CPU times: user 6min 18s, sys: 7.6 s, total: 6min 26s\n", + "Wall time: 6min 11s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 41, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -1481,32 +1578,31 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "18af83b15a1041de92fab2cdb62379ab", + "model_id": "a11ad50979364f5985370f3942eb9c8c", "version_major": 2, "version_minor": 0 }, @@ -1521,8 +1617,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 27.6 s, sys: 3.45 s, total: 31.1 s\n", - "Wall time: 22.3 s\n" + "CPU times: user 2min 35s, sys: 4.93 s, total: 2min 40s\n", + "Wall time: 23.4 s\n" ] } ], @@ -1539,13 +1635,10 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ - "# TODO: drop `apply(str)`\n", - "recos[\"item_id\"] = recos[\"item_id\"].apply(str)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(str)\n", "metric_values = calc_metrics(metrics, recos[[\"user_id\", \"item_id\", \"rank\"]], test, train, catalog)\n", "metric_values[\"model\"] = \"sasrec_ids\"\n", "features_results.append(metric_values)" @@ -1560,9 +1653,19 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 43, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -1577,31 +1680,35 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 935 K \n", - "---------------------------------------------------------------\n", - "935 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 3.4 M | train\n", + "-----------------------------------------------------------------------\n", + "3.4 M Trainable params\n", "0 Non-trainable params\n", - "935 K Total params\n", - "3.742 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "3.4 M Total params\n", + "13.621 Total estimated model params size (MB)\n", + "39 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1217d0ed8fac4404aca69428054e767f", + "model_id": "0184c3daa298402da6b3ec16aa38a5dc", "version_major": 2, "version_minor": 0 }, @@ -1612,14 +1719,6 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", - " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n" - ] - }, { "name": "stderr", "output_type": "stream", @@ -1630,47 +1729,46 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 45, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#%%time\n", - "model.fit(dataset_item_features)" + "model.fit(dataset_item_features_genre_director)" ] }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b66c4dbd535f4601af3e8fb91b0de523", + "model_id": "2b68d1c5bf7543c599a45b39190b1fbc", "version_major": 2, "version_minor": 0 }, @@ -1685,8 +1783,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 28.9 s, sys: 3.6 s, total: 32.5 s\n", - "Wall time: 22.2 s\n" + "CPU times: user 2min 37s, sys: 8.99 s, total: 2min 46s\n", + "Wall time: 21.8 s\n" ] } ], @@ -1703,15 +1801,12 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ - "# TODO: drop `apply(str)`\n", - "recos[\"item_id\"] = recos[\"item_id\"].apply(str)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(str)\n", "metric_values = calc_metrics(metrics, recos[[\"user_id\", \"item_id\", \"rank\"]], test, train, catalog)\n", - "metric_values[\"model\"] = \"sasrec_ids_cat\"\n", + "metric_values[\"model\"] = \"sasrec_id_and_cat_features\"\n", "features_results.append(metric_values)" ] }, @@ -1724,9 +1819,19 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 52, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "model = SASRecModel(\n", " n_blocks=2,\n", @@ -1741,31 +1846,35 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", + " unq_values = pd.unique(values)\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 211 K \n", - "---------------------------------------------------------------\n", - "211 K Trainable params\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------------\n", + "0 | torch_model | TransformerBasedSessionEncoder | 2.0 M | train\n", + "-----------------------------------------------------------------------\n", + "2.0 M Trainable params\n", "0 Non-trainable params\n", - "211 K Total params\n", - "0.847 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "2.0 M Total params\n", + "7.832 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d1ceac774c614b85bc31d5bd3af6f75c", + "model_id": "54489a8a6a27455bb0d2b7c9588f0ec8", "version_major": 2, "version_minor": 0 }, @@ -1780,55 +1889,52 @@ "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", - " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 50, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#%%time\n", - "model.fit(dataset_item_features)" + "model.fit(dataset_item_features_genre_director)" ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bddb0ea860994c5e9eafb49b598756a0", + "model_id": "ec97ecaa85eb4c3ead46f112d1acf090", "version_major": 2, "version_minor": 0 }, @@ -1843,8 +1949,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 24.3 s, sys: 2.53 s, total: 26.9 s\n", - "Wall time: 19.2 s\n" + "CPU times: user 2min 31s, sys: 8.96 s, total: 2min 40s\n", + "Wall time: 18.7 s\n" ] } ], @@ -1861,99 +1967,86 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ - "# TODO: drop `apply(str)`\n", - "recos[\"item_id\"] = recos[\"item_id\"].apply(str)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(str)\n", "metric_values = calc_metrics(metrics, recos[[\"user_id\", \"item_id\", \"rank\"]], test, train, catalog)\n", - "metric_values[\"model\"] = \"sasrec_cat\"\n", + "metric_values[\"model\"] = \"sasrec_cat_features\"\n", "features_results.append(metric_values)" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'MAP@1': 0.048488013974905805,\n", - " 'MAP@5': 0.0821411524595932,\n", - " 'MAP@10': 0.0912577746921091,\n", - " 'MIUF@1': 3.868676865340589,\n", - " 'MIUF@5': 4.506241791317061,\n", - " 'MIUF@10': 5.087812416018942,\n", - " 'Serendipity@1': 0.0010587683094952943,\n", - " 'Serendipity@5': 0.0008279085147448243,\n", - " 'Serendipity@10': 0.0007506236395264775,\n", + "[{'MAP@1': 0.04846577699474078,\n", + " 'MAP@5': 0.0816953145406517,\n", + " 'MAP@10': 0.09070442769366964,\n", + " 'MIUF@1': 3.871426206344739,\n", + " 'MIUF@5': 4.573068555853547,\n", + " 'MIUF@10': 5.159742458558834,\n", + " 'Serendipity@1': 0.001116687417059873,\n", + " 'Serendipity@5': 0.0008645696959881002,\n", + " 'Serendipity@10': 0.0007632648657992071,\n", " 'model': 'softmax'},\n", - " {'MAP@1': 0.043093656973408605,\n", - " 'MAP@5': 0.07398385923675324,\n", - " 'MAP@10': 0.08294397368660968,\n", - " 'MIUF@1': 3.598352040478207,\n", - " 'MIUF@5': 4.304012747602231,\n", - " 'MIUF@10': 4.916007085918255,\n", - " 'Serendipity@1': 0.0004767361940337927,\n", - " 'Serendipity@5': 0.0004863863900258476,\n", - " 'Serendipity@10': 0.000501558401864858,\n", + " {'MAP@1': 0.02703450310364319,\n", + " 'MAP@5': 0.05124396949349954,\n", + " 'MAP@10': 0.05907958022653049,\n", + " 'MIUF@1': 3.882081042459438,\n", + " 'MIUF@5': 4.384313936251787,\n", + " 'MIUF@10': 4.734298278984563,\n", + " 'Serendipity@1': 0.00010437879417622002,\n", + " 'Serendipity@5': 0.0001209341551851975,\n", + " 'Serendipity@10': 0.0001308852660453074,\n", " 'model': 'bce'},\n", - " {'MAP@1': 0.046172701078265196,\n", - " 'MAP@5': 0.08031691396493008,\n", - " 'MAP@10': 0.08918565294041045,\n", - " 'MIUF@1': 3.1867943881001892,\n", - " 'MIUF@5': 3.822622559312681,\n", - " 'MIUF@10': 4.538795926216838,\n", - " 'Serendipity@1': 0.0006401436142825035,\n", - " 'Serendipity@5': 0.0005169224313850365,\n", - " 'Serendipity@10': 0.0005046814909982423,\n", + " {'MAP@1': 0.04084812884382748,\n", + " 'MAP@5': 0.07235604259743772,\n", + " 'MAP@10': 0.08016616686270196,\n", + " 'MIUF@1': 2.33239724771057,\n", + " 'MIUF@5': 3.093763291371006,\n", + " 'MIUF@10': 3.9422054591506033,\n", + " 'Serendipity@1': 0.00010303205538126172,\n", + " 'Serendipity@5': 0.00011795153034776448,\n", + " 'Serendipity@10': 0.00013442022189753792,\n", " 'model': 'gBCE'},\n", - " {'MAP@1': 0.04690299592170011,\n", - " 'MAP@5': 0.08084021275405333,\n", - " 'MAP@10': 0.08995189251861414,\n", - " 'MIUF@1': 4.044813264146811,\n", - " 'MIUF@5': 4.601506122113505,\n", - " 'MIUF@10': 5.091947950604172,\n", - " 'Serendipity@1': 0.0010635295261226848,\n", - " 'Serendipity@5': 0.0008217942085431382,\n", - " 'Serendipity@10': 0.0007197978905568893,\n", - " 'model': 'softmax_padding_mask'},\n", - " {'MAP@1': 0.04869646365671735,\n", - " 'MAP@5': 0.08268167732074286,\n", - " 'MAP@10': 0.09205364831157593,\n", + " {'MAP@1': 0.048147530079556855,\n", + " 'MAP@5': 0.08174752632088274,\n", + " 'MAP@10': 0.09082809872497208,\n", " 'MIUF@1': 18.824620072061013,\n", " 'MIUF@5': 18.824620072061013,\n", " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.10001583865688189,\n", - " 'Serendipity@5': 0.06081037195832809,\n", - " 'Serendipity@10': 0.04484003521581858,\n", + " 'Serendipity@1': 0.09930837864949052,\n", + " 'Serendipity@5': 0.06018338155520556,\n", + " 'Serendipity@10': 0.04419324756562254,\n", " 'model': 'sasrec_ids'},\n", - " {'MAP@1': 0.04735485773892377,\n", - " 'MAP@5': 0.08026077112292192,\n", - " 'MAP@10': 0.08941884083577493,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.09770339475212501,\n", - " 'Serendipity@5': 0.059226925047999514,\n", - " 'Serendipity@10': 0.04397441129046034,\n", - " 'model': 'sasrec_ids_cat'},\n", - " {'MAP@1': 0.0016025375062613473,\n", - " 'MAP@5': 0.005957035647418842,\n", - " 'MAP@10': 0.006956210940042861,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.005078929306794783,\n", - " 'Serendipity@5': 0.006295511867818093,\n", - " 'Serendipity@10': 0.005134648242893923,\n", - " 'model': 'sasrec_cat'}]" + " {'MAP@1': 0.04766175619468918,\n", + " 'MAP@5': 0.08225834198336934,\n", + " 'MAP@10': 0.09147359309311318,\n", + " 'MIUF@1': 3.9415828408413325,\n", + " 'MIUF@5': 4.572004200795584,\n", + " 'MIUF@10': 5.181503181696694,\n", + " 'Serendipity@1': 0.0012604074213600975,\n", + " 'Serendipity@5': 0.0009317855973533245,\n", + " 'Serendipity@10': 0.0008227161922729454,\n", + " 'model': 'sasrec_id_and_cat_features'},\n", + " {'MAP@1': 0.04310550975326949,\n", + " 'MAP@5': 0.07036672111056717,\n", + " 'MAP@10': 0.07822710009694937,\n", + " 'MIUF@1': 4.18466637016627,\n", + " 'MIUF@5': 5.596602216823304,\n", + " 'MIUF@10': 6.135927367440625,\n", + " 'Serendipity@1': 0.0010273693092966185,\n", + " 'Serendipity@5': 0.00088834477237887,\n", + " 'Serendipity@10': 0.0007727660052644417,\n", + " 'model': 'sasrec_cat_features'}]" ] }, - "execution_count": 53, + "execution_count": 56, "metadata": {}, "output_type": "execute_result" } @@ -1962,6 +2055,175 @@ "features_results" ] }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAP@1MAP@5MAP@10MIUF@1MIUF@5MIUF@10Serendipity@1Serendipity@5Serendipity@10
model
sasrec_id_and_cat_features0.0476620.0822580.0914743.9415834.5720045.1815030.0012600.0009320.000823
sasrec_ids0.0481480.0817480.09082818.82462018.82462018.8246200.0993080.0601830.044193
softmax0.0484660.0816950.0907043.8714264.5730695.1597420.0011170.0008650.000763
gBCE0.0408480.0723560.0801662.3323973.0937633.9422050.0001030.0001180.000134
sasrec_cat_features0.0431060.0703670.0782274.1846665.5966026.1359270.0010270.0008880.000773
bce0.0270350.0512440.0590803.8820814.3843144.7342980.0001040.0001210.000131
\n", + "
" + ], + "text/plain": [ + " MAP@1 MAP@5 MAP@10 MIUF@1 \\\n", + "model \n", + "sasrec_id_and_cat_features 0.047662 0.082258 0.091474 3.941583 \n", + "sasrec_ids 0.048148 0.081748 0.090828 18.824620 \n", + "softmax 0.048466 0.081695 0.090704 3.871426 \n", + "gBCE 0.040848 0.072356 0.080166 2.332397 \n", + "sasrec_cat_features 0.043106 0.070367 0.078227 4.184666 \n", + "bce 0.027035 0.051244 0.059080 3.882081 \n", + "\n", + " MIUF@5 MIUF@10 Serendipity@1 \\\n", + "model \n", + "sasrec_id_and_cat_features 4.572004 5.181503 0.001260 \n", + "sasrec_ids 18.824620 18.824620 0.099308 \n", + "softmax 4.573069 5.159742 0.001117 \n", + "gBCE 3.093763 3.942205 0.000103 \n", + "sasrec_cat_features 5.596602 6.135927 0.001027 \n", + "bce 4.384314 4.734298 0.000104 \n", + "\n", + " Serendipity@5 Serendipity@10 \n", + "model \n", + "sasrec_id_and_cat_features 0.000932 0.000823 \n", + "sasrec_ids 0.060183 0.044193 \n", + "softmax 0.000865 0.000763 \n", + "gBCE 0.000118 0.000134 \n", + "sasrec_cat_features 0.000888 0.000773 \n", + "bce 0.000121 0.000131 " + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_df = (\n", + " pd.DataFrame(features_results)\n", + " .set_index(\"model\")\n", + " .sort_values(by=[\"MAP@10\", \"Serendipity@10\"], ascending=False)\n", + ")\n", + "features_df" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1971,7 +2233,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -1980,15 +2242,15 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1.43 s, sys: 197 ms, total: 1.63 s\n", - "Wall time: 1.07 s\n" + "CPU times: user 3.14 s, sys: 4.21 s, total: 7.35 s\n", + "Wall time: 1.15 s\n" ] } ], @@ -2005,7 +2267,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 60, "metadata": {}, "outputs": [ { @@ -2039,211 +2301,211 @@ " \n", " 0\n", " 13865\n", - " 15648\n", + " 11863\n", " 1.000000\n", " 1\n", " \n", " \n", " 1\n", " 13865\n", - " 3386\n", + " 7107\n", " 1.000000\n", " 2\n", " \n", " \n", " 2\n", " 13865\n", - " 16194\n", - " 0.908769\n", + " 6409\n", + " 0.628877\n", " 3\n", " \n", " \n", " 3\n", " 13865\n", - " 147\n", - " 0.908769\n", + " 142\n", + " 0.559630\n", " 4\n", " \n", " \n", " 4\n", " 13865\n", - " 12586\n", - " 0.908769\n", + " 2657\n", + " 0.514484\n", " 5\n", " \n", " \n", " 5\n", " 13865\n", - " 12309\n", - " 0.908769\n", + " 4457\n", + " 0.503537\n", " 6\n", " \n", " \n", " 6\n", " 13865\n", - " 6661\n", - " 0.908769\n", + " 15297\n", + " 0.500209\n", " 7\n", " \n", " \n", " 7\n", " 13865\n", - " 2255\n", - " 0.908769\n", + " 6809\n", + " 0.487185\n", " 8\n", " \n", " \n", " 8\n", " 13865\n", - " 4130\n", - " 0.908769\n", + " 10772\n", + " 0.485932\n", " 9\n", " \n", " \n", " 9\n", " 13865\n", - " 9109\n", - " 0.908769\n", + " 10440\n", + " 0.473830\n", " 10\n", " \n", " \n", " 10\n", " 4457\n", - " 5109\n", - " 1.000000\n", + " 14741\n", + " 0.725818\n", " 1\n", " \n", " \n", " 11\n", " 4457\n", - " 8851\n", - " 1.000000\n", + " 12995\n", + " 0.658568\n", " 2\n", " \n", " \n", " 12\n", " 4457\n", - " 8486\n", - " 1.000000\n", + " 142\n", + " 0.646901\n", " 3\n", " \n", " \n", " 13\n", " 4457\n", - " 12087\n", - " 1.000000\n", + " 3935\n", + " 0.641437\n", " 4\n", " \n", " \n", " 14\n", " 4457\n", - " 2313\n", - " 1.000000\n", + " 10772\n", + " 0.621477\n", " 5\n", " \n", " \n", " 15\n", " 4457\n", - " 11977\n", - " 1.000000\n", + " 1287\n", + " 0.611060\n", " 6\n", " \n", " \n", " 16\n", " 4457\n", - " 3384\n", - " 1.000000\n", + " 3509\n", + " 0.610761\n", " 7\n", " \n", " \n", " 17\n", " 4457\n", - " 6285\n", - " 1.000000\n", + " 15464\n", + " 0.592197\n", " 8\n", " \n", " \n", " 18\n", " 4457\n", - " 7928\n", - " 1.000000\n", + " 274\n", + " 0.587578\n", " 9\n", " \n", " \n", " 19\n", " 4457\n", - " 11513\n", - " 1.000000\n", + " 6455\n", + " 0.583398\n", " 10\n", " \n", " \n", " 20\n", " 15297\n", - " 8723\n", - " 1.000000\n", + " 142\n", + " 0.614521\n", " 1\n", " \n", " \n", " 21\n", " 15297\n", - " 5926\n", - " 1.000000\n", + " 2657\n", + " 0.567724\n", " 2\n", " \n", " \n", " 22\n", " 15297\n", - " 4131\n", - " 1.000000\n", + " 6809\n", + " 0.554094\n", " 3\n", " \n", " \n", " 23\n", " 15297\n", - " 4229\n", - " 1.000000\n", + " 10772\n", + " 0.539985\n", " 4\n", " \n", " \n", " 24\n", " 15297\n", - " 7005\n", - " 1.000000\n", + " 10440\n", + " 0.529680\n", " 5\n", " \n", " \n", " 25\n", " 15297\n", - " 10797\n", - " 1.000000\n", + " 14337\n", + " 0.514991\n", " 6\n", " \n", " \n", " 26\n", " 15297\n", - " 10535\n", - " 1.000000\n", + " 1844\n", + " 0.514991\n", " 7\n", " \n", " \n", " 27\n", " 15297\n", - " 5400\n", - " 1.000000\n", + " 4457\n", + " 0.507450\n", " 8\n", " \n", " \n", " 28\n", " 15297\n", - " 4716\n", - " 1.000000\n", + " 13865\n", + " 0.500209\n", " 9\n", " \n", " \n", " 29\n", " 15297\n", - " 13103\n", - " 1.000000\n", + " 7107\n", + " 0.500209\n", " 10\n", " \n", " \n", @@ -2252,39 +2514,39 @@ ], "text/plain": [ " target_item_id item_id score rank\n", - "0 13865 15648 1.000000 1\n", - "1 13865 3386 1.000000 2\n", - "2 13865 16194 0.908769 3\n", - "3 13865 147 0.908769 4\n", - "4 13865 12586 0.908769 5\n", - "5 13865 12309 0.908769 6\n", - "6 13865 6661 0.908769 7\n", - "7 13865 2255 0.908769 8\n", - "8 13865 4130 0.908769 9\n", - "9 13865 9109 0.908769 10\n", - "10 4457 5109 1.000000 1\n", - "11 4457 8851 1.000000 2\n", - "12 4457 8486 1.000000 3\n", - "13 4457 12087 1.000000 4\n", - "14 4457 2313 1.000000 5\n", - "15 4457 11977 1.000000 6\n", - "16 4457 3384 1.000000 7\n", - "17 4457 6285 1.000000 8\n", - "18 4457 7928 1.000000 9\n", - "19 4457 11513 1.000000 10\n", - "20 15297 8723 1.000000 1\n", - "21 15297 5926 1.000000 2\n", - "22 15297 4131 1.000000 3\n", - "23 15297 4229 1.000000 4\n", - "24 15297 7005 1.000000 5\n", - "25 15297 10797 1.000000 6\n", - "26 15297 10535 1.000000 7\n", - "27 15297 5400 1.000000 8\n", - "28 15297 4716 1.000000 9\n", - "29 15297 13103 1.000000 10" + "0 13865 11863 1.000000 1\n", + "1 13865 7107 1.000000 2\n", + "2 13865 6409 0.628877 3\n", + "3 13865 142 0.559630 4\n", + "4 13865 2657 0.514484 5\n", + "5 13865 4457 0.503537 6\n", + "6 13865 15297 0.500209 7\n", + "7 13865 6809 0.487185 8\n", + "8 13865 10772 0.485932 9\n", + "9 13865 10440 0.473830 10\n", + "10 4457 14741 0.725818 1\n", + "11 4457 12995 0.658568 2\n", + "12 4457 142 0.646901 3\n", + "13 4457 3935 0.641437 4\n", + "14 4457 10772 0.621477 5\n", + "15 4457 1287 0.611060 6\n", + "16 4457 3509 0.610761 7\n", + "17 4457 15464 0.592197 8\n", + "18 4457 274 0.587578 9\n", + "19 4457 6455 0.583398 10\n", + "20 15297 142 0.614521 1\n", + "21 15297 2657 0.567724 2\n", + "22 15297 6809 0.554094 3\n", + "23 15297 10772 0.539985 4\n", + "24 15297 10440 0.529680 5\n", + "25 15297 14337 0.514991 6\n", + "26 15297 1844 0.514991 7\n", + "27 15297 4457 0.507450 8\n", + "28 15297 13865 0.500209 9\n", + "29 15297 7107 0.500209 10" ] }, - "execution_count": 56, + "execution_count": 60, "metadata": {}, "output_type": "execute_result" } @@ -2295,7 +2557,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 61, "metadata": {}, "outputs": [ { @@ -2330,282 +2592,282 @@ " \n", " 0\n", " 13865\n", - " 15648\n", + " 11863\n", " 1.000000\n", " 1\n", - " Черное золото\n", + " Девятаев - сериал\n", " \n", " \n", " 1\n", " 13865\n", - " 3386\n", + " 7107\n", " 1.000000\n", " 2\n", - " Спартак\n", + " Девятаев\n", " \n", " \n", " 2\n", " 13865\n", - " 16194\n", - " 0.908769\n", + " 6409\n", + " 0.628877\n", " 3\n", - " Голубая линия\n", + " Особо опасен\n", " \n", " \n", " 3\n", " 13865\n", - " 147\n", - " 0.908769\n", + " 142\n", + " 0.559630\n", " 4\n", - " Единичка\n", + " Маша\n", " \n", " \n", " 4\n", " 13865\n", - " 12586\n", - " 0.908769\n", + " 2657\n", + " 0.514484\n", " 5\n", - " Вспоминая 1942\n", + " Подслушано\n", " \n", " \n", " 5\n", " 13865\n", - " 12309\n", - " 0.908769\n", + " 4457\n", + " 0.503537\n", " 6\n", - " Враг у ворот\n", + " 2067: Петля времени\n", " \n", " \n", " 6\n", " 13865\n", - " 6661\n", - " 0.908769\n", + " 15297\n", + " 0.500209\n", " 7\n", - " Солдатик\n", + " Клиника счастья\n", " \n", " \n", " 7\n", " 13865\n", - " 2255\n", - " 0.908769\n", + " 6809\n", + " 0.487185\n", " 8\n", - " Пленный\n", + " Дуров\n", " \n", " \n", " 8\n", " 13865\n", - " 4130\n", - " 0.908769\n", + " 10772\n", + " 0.485932\n", " 9\n", - " Пустота\n", + " Зелёная книга\n", " \n", " \n", " 9\n", " 13865\n", - " 9109\n", - " 0.908769\n", + " 10440\n", + " 0.473830\n", " 10\n", - " Последняя битва\n", + " Хрустальный\n", " \n", " \n", " 10\n", " 4457\n", - " 5109\n", - " 1.000000\n", + " 14741\n", + " 0.725818\n", " 1\n", - " Время разлуки\n", + " Цвет из иных миров\n", " \n", " \n", " 11\n", " 4457\n", - " 8851\n", - " 1.000000\n", + " 12995\n", + " 0.658568\n", " 2\n", - " Лисы\n", + " Восемь сотен\n", " \n", " \n", " 12\n", " 4457\n", - " 8486\n", - " 1.000000\n", + " 142\n", + " 0.646901\n", " 3\n", - " Мой создатель\n", + " Маша\n", " \n", " \n", " 13\n", " 4457\n", - " 12087\n", - " 1.000000\n", + " 3935\n", + " 0.641437\n", " 4\n", - " Молчаливое бегство\n", + " Бывшая с того света\n", " \n", " \n", " 14\n", " 4457\n", - " 2313\n", - " 1.000000\n", + " 10772\n", + " 0.621477\n", " 5\n", - " Свет моей жизни\n", + " Зелёная книга\n", " \n", " \n", " 15\n", " 4457\n", - " 11977\n", - " 1.000000\n", + " 1287\n", + " 0.611060\n", " 6\n", - " Зоология\n", + " Терминатор: Тёмные судьбы\n", " \n", " \n", " 16\n", " 4457\n", - " 3384\n", - " 1.000000\n", + " 3509\n", + " 0.610761\n", " 7\n", - " Вивариум\n", + " Комната желаний\n", " \n", " \n", " 17\n", " 4457\n", - " 6285\n", - " 1.000000\n", + " 15464\n", + " 0.592197\n", " 8\n", - " Божественная любовь\n", + " Апгрейд\n", " \n", " \n", " 18\n", " 4457\n", - " 7928\n", - " 1.000000\n", + " 274\n", + " 0.587578\n", " 9\n", - " Вечная жизнь\n", + " Логан\n", " \n", " \n", " 19\n", " 4457\n", - " 11513\n", - " 1.000000\n", + " 6455\n", + " 0.583398\n", " 10\n", - " Любовь\n", + " Альфа\n", " \n", " \n", " 20\n", " 15297\n", - " 8723\n", - " 1.000000\n", + " 142\n", + " 0.614521\n", " 1\n", - " Секс в другом городе: Поколение Q\n", + " Маша\n", " \n", " \n", " 21\n", " 15297\n", - " 5926\n", - " 1.000000\n", + " 2657\n", + " 0.567724\n", " 2\n", - " Пациенты\n", + " Подслушано\n", " \n", " \n", " 22\n", " 15297\n", - " 4131\n", - " 1.000000\n", + " 6809\n", + " 0.554094\n", " 3\n", - " Учитель Ким, доктор Романтик\n", + " Дуров\n", " \n", " \n", " 23\n", " 15297\n", - " 4229\n", - " 1.000000\n", + " 10772\n", + " 0.539985\n", " 4\n", - " Самара\n", + " Зелёная книга\n", " \n", " \n", " 24\n", " 15297\n", - " 7005\n", - " 1.000000\n", + " 10440\n", + " 0.529680\n", " 5\n", - " Чёрная кровь\n", + " Хрустальный\n", " \n", " \n", " 25\n", " 15297\n", - " 10797\n", - " 1.000000\n", + " 14337\n", + " 0.514991\n", " 6\n", - " Наследники\n", + " [4К] Аферистка\n", " \n", " \n", " 26\n", " 15297\n", - " 10535\n", - " 1.000000\n", + " 1844\n", + " 0.514991\n", " 7\n", - " Я могу уничтожить тебя\n", + " Аферистка\n", " \n", " \n", " 27\n", " 15297\n", - " 5400\n", - " 1.000000\n", + " 4457\n", + " 0.507450\n", " 8\n", - " Частица вселенной\n", + " 2067: Петля времени\n", " \n", " \n", " 28\n", " 15297\n", - " 4716\n", - " 1.000000\n", + " 13865\n", + " 0.500209\n", " 9\n", - " Мастера секса\n", + " Девятаев\n", " \n", " \n", " 29\n", " 15297\n", - " 13103\n", - " 1.000000\n", + " 7107\n", + " 0.500209\n", " 10\n", - " Хороший доктор\n", + " Девятаев\n", " \n", " \n", "\n", "" ], "text/plain": [ - " target_item_id item_id score rank title\n", - "0 13865 15648 1.000000 1 Черное золото\n", - "1 13865 3386 1.000000 2 Спартак\n", - "2 13865 16194 0.908769 3 Голубая линия\n", - "3 13865 147 0.908769 4 Единичка\n", - "4 13865 12586 0.908769 5 Вспоминая 1942\n", - "5 13865 12309 0.908769 6 Враг у ворот\n", - "6 13865 6661 0.908769 7 Солдатик\n", - "7 13865 2255 0.908769 8 Пленный\n", - "8 13865 4130 0.908769 9 Пустота\n", - "9 13865 9109 0.908769 10 Последняя битва\n", - "10 4457 5109 1.000000 1 Время разлуки\n", - "11 4457 8851 1.000000 2 Лисы\n", - "12 4457 8486 1.000000 3 Мой создатель\n", - "13 4457 12087 1.000000 4 Молчаливое бегство\n", - "14 4457 2313 1.000000 5 Свет моей жизни\n", - "15 4457 11977 1.000000 6 Зоология\n", - "16 4457 3384 1.000000 7 Вивариум\n", - "17 4457 6285 1.000000 8 Божественная любовь\n", - "18 4457 7928 1.000000 9 Вечная жизнь\n", - "19 4457 11513 1.000000 10 Любовь\n", - "20 15297 8723 1.000000 1 Секс в другом городе: Поколение Q\n", - "21 15297 5926 1.000000 2 Пациенты\n", - "22 15297 4131 1.000000 3 Учитель Ким, доктор Романтик\n", - "23 15297 4229 1.000000 4 Самара\n", - "24 15297 7005 1.000000 5 Чёрная кровь\n", - "25 15297 10797 1.000000 6 Наследники\n", - "26 15297 10535 1.000000 7 Я могу уничтожить тебя\n", - "27 15297 5400 1.000000 8 Частица вселенной\n", - "28 15297 4716 1.000000 9 Мастера секса\n", - "29 15297 13103 1.000000 10 Хороший доктор" + " target_item_id item_id score rank title\n", + "0 13865 11863 1.000000 1 Девятаев - сериал\n", + "1 13865 7107 1.000000 2 Девятаев\n", + "2 13865 6409 0.628877 3 Особо опасен\n", + "3 13865 142 0.559630 4 Маша\n", + "4 13865 2657 0.514484 5 Подслушано\n", + "5 13865 4457 0.503537 6 2067: Петля времени\n", + "6 13865 15297 0.500209 7 Клиника счастья\n", + "7 13865 6809 0.487185 8 Дуров\n", + "8 13865 10772 0.485932 9 Зелёная книга\n", + "9 13865 10440 0.473830 10 Хрустальный\n", + "10 4457 14741 0.725818 1 Цвет из иных миров\n", + "11 4457 12995 0.658568 2 Восемь сотен\n", + "12 4457 142 0.646901 3 Маша\n", + "13 4457 3935 0.641437 4 Бывшая с того света\n", + "14 4457 10772 0.621477 5 Зелёная книга\n", + "15 4457 1287 0.611060 6 Терминатор: Тёмные судьбы\n", + "16 4457 3509 0.610761 7 Комната желаний\n", + "17 4457 15464 0.592197 8 Апгрейд\n", + "18 4457 274 0.587578 9 Логан\n", + "19 4457 6455 0.583398 10 Альфа\n", + "20 15297 142 0.614521 1 Маша\n", + "21 15297 2657 0.567724 2 Подслушано\n", + "22 15297 6809 0.554094 3 Дуров\n", + "23 15297 10772 0.539985 4 Зелёная книга\n", + "24 15297 10440 0.529680 5 Хрустальный\n", + "25 15297 14337 0.514991 6 [4К] Аферистка\n", + "26 15297 1844 0.514991 7 Аферистка\n", + "27 15297 4457 0.507450 8 2067: Петля времени\n", + "28 15297 13865 0.500209 9 Девятаев\n", + "29 15297 7107 0.500209 10 Девятаев" ] }, - "execution_count": 57, + "execution_count": 61, "metadata": {}, "output_type": "execute_result" } @@ -2625,9 +2887,9 @@ ], "metadata": { "kernelspec": { - "display_name": "rectools_origin", + "display_name": "venv", "language": "python", - "name": "rectools_origin" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 1c9c9ee8..b9c358b6 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -61,19 +61,24 @@ class CatFeaturesItemNet(ItemNetBase): def __init__( self, - item_features: SparseFeatures, + emb_bag_inputs: torch.Tensor, + len_indexes: torch.Tensor, + offsets: torch.Tensor, + n_cat_features: int, n_factors: int, dropout_rate: float, ): super().__init__() - self.item_features = item_features - self.n_items = len(item_features) - self.n_cat_features = len(item_features.names) - - self.category_embeddings = nn.Embedding(num_embeddings=self.n_cat_features, embedding_dim=n_factors) + self.n_cat_features = n_cat_features + self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_features, embedding_dim=n_factors, padding_idx=0) self.drop_layer = nn.Dropout(dropout_rate) + self.register_buffer("offsets", offsets) + self.register_buffer("emb_bag_indexes", torch.arange(len(emb_bag_inputs))) + self.register_buffer("emb_bag_inputs", emb_bag_inputs) + self.register_buffer("len_indexes", len_indexes) + def forward(self, items: torch.Tensor) -> torch.Tensor: """ Forward pass to get item embeddings from categorical item features. @@ -88,36 +93,26 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: torch.Tensor Item embeddings. """ - feature_dense = self.get_dense_item_features(items) - - feature_embs = self.category_embeddings(self.feature_catalog.to(self.device)) - feature_embs = self.drop_layer(feature_embs) - - feature_embeddings_per_items = feature_dense.to(self.device) @ feature_embs + item_emb_bag_inputs, item_offsets = self.get_item_inputs_offsets(items) + feature_embeddings_per_items = self.embedding_bag(input=item_emb_bag_inputs, offsets=item_offsets) + feature_embeddings_per_items = self.drop_layer(feature_embeddings_per_items) return feature_embeddings_per_items + def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get categorical item features and offsets for `items`.""" + item_indexes = self.offsets[items].unsqueeze(-1) + self.emb_bag_indexes + length_mask = self.emb_bag_indexes < self.len_indexes[items].unsqueeze(-1) + item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]].squeeze(-1) + item_offsets = torch.cat( + (torch.tensor([0], device=self.device), torch.cumsum(self.len_indexes[items], dim=0)[:-1]) + ) + return item_emb_bag_inputs, item_offsets + @property def feature_catalog(self) -> torch.Tensor: """Return tensor with elements in range [0, n_cat_features).""" return torch.arange(0, self.n_cat_features) - def get_dense_item_features(self, items: torch.Tensor) -> torch.Tensor: - """ - Get categorical item values by certain item ids in dense format. - - Parameters - ---------- - items : torch.Tensor - Internal item ids. - - Returns - ------- - torch.Tensor - categorical item values in dense format. - """ - feature_dense = self.item_features.take(items.detach().cpu().numpy()).get_dense() - return torch.from_numpy(feature_dense) - @classmethod def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> tp.Optional[tpe.Self]: """ @@ -156,7 +151,19 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> warnings.warn(explanation) return None - return cls(item_cat_features, n_factors, dropout_rate) + emb_bag_inputs = torch.tensor(item_cat_features.values.indices, dtype=torch.long) + offsets = torch.tensor(item_cat_features.values.indptr, dtype=torch.long) + len_indexes = torch.diff(offsets, dim=0) + n_cat_features = len(item_cat_features.names) + + return cls( + emb_bag_inputs=emb_bag_inputs, + offsets=offsets[:-1], + len_indexes=len_indexes, + n_cat_features=n_cat_features, + n_factors=n_factors, + dropout_rate=dropout_rate, + ) class IdEmbeddingsItemNet(ItemNetBase): From 391267384ab8a14d85dc310f5e5434d9dffdf6ba Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Mon, 3 Feb 2025 20:14:06 +0300 Subject: [PATCH 2/5] fixed embedding_bag for 1 element, merge masks --- examples/sasrec_metrics_comp.ipynb | 859 +++++++++++++------------ rectools/models/nn/item_net.py | 33 +- rectools/models/nn/transformer_base.py | 2 +- tests/models/nn/test_item_net.py | 60 +- 4 files changed, 475 insertions(+), 479 deletions(-) diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb index a144af86..d2131880 100644 --- a/examples/sasrec_metrics_comp.ipynb +++ b/examples/sasrec_metrics_comp.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -139,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +151,7 @@ "genre_feature.columns = [\"id\", \"value\"]\n", "genre_feature[\"feature\"] = \"genre\"\n", "\n", - "items[\"director\"] = items[\"directors\"].str.lower().str.replace(\" \", \"\", regex=False).replace(\", \", \",\", regex=False).str.split(\",\")\n", + "items[\"director\"] = items[\"directors\"].str.lower().replace(\", \", \",\", regex=False).str.split(\",\")\n", "directors_feature = items[[\"item_id\", \"director\"]].explode(\"director\")\n", "directors_feature.columns = [\"id\", \"value\"]\n", "directors_feature[\"feature\"] = \"director\"\n", @@ -259,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -287,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -315,7 +315,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2503c49e4f98450c91d2b1cfd0abbe8c", + "model_id": "06699b98353c4d2298bb1037f77fdf71", "version_major": 2, "version_minor": 0 }, @@ -337,29 +337,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 6min 12s, sys: 7.72 s, total: 6min 20s\n", - "Wall time: 6min 9s\n" + "CPU times: user 7min 41s, sys: 19.3 s, total: 8min\n", + "Wall time: 7min 50s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", - "model.fit(dataset_no_features)" + "model.fit(dataset_item_features_genre_director)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -383,7 +383,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cf1fbcbcf8af474db553209f7e279c63", + "model_id": "6a4509eead0042a28799087cefb68d47", "version_major": 2, "version_minor": 0 }, @@ -398,8 +398,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 27s, sys: 5.38 s, total: 2min 32s\n", - "Wall time: 19.2 s\n" + "CPU times: user 2min 43s, sys: 8.04 s, total: 2min 51s\n", + "Wall time: 25 s\n" ] } ], @@ -416,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -427,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -445,7 +445,7 @@ " 'model': 'softmax'}]" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -456,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -586,7 +586,7 @@ "[947050 rows x 4 columns]" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -604,7 +604,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -620,7 +620,7 @@ "32" ] }, - "execution_count": 20, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -633,7 +633,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -662,7 +662,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -690,7 +690,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5dc9b5b0a50c416d97913c9f4b1d8c70", + "model_id": "977a6b52aaf24f4fa0ccb05a7c9804ba", "version_major": 2, "version_minor": 0 }, @@ -712,17 +712,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 5min 41s, sys: 7.63 s, total: 5min 48s\n", - "Wall time: 5min 33s\n" + "CPU times: user 7min 41s, sys: 18 s, total: 7min 59s\n", + "Wall time: 7min 51s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -734,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -758,7 +758,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ca2c4093c5464072a1affed53e39757c", + "model_id": "4aa1c3be58ae41bb9aa81274121df9ed", "version_major": 2, "version_minor": 0 }, @@ -773,8 +773,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 53s, sys: 6.47 s, total: 2min 59s\n", - "Wall time: 21.1 s\n" + "CPU times: user 2min 42s, sys: 6.58 s, total: 2min 48s\n", + "Wall time: 24.4 s\n" ] } ], @@ -791,7 +791,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -802,7 +802,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -932,7 +932,7 @@ "[947050 rows x 4 columns]" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -950,7 +950,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -966,7 +966,7 @@ "32" ] }, - "execution_count": 27, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -979,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1009,7 +1009,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1037,7 +1037,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "79f807c7c72945b7b63dbc911107b461", + "model_id": "ccf86bc484ed4bbe804a088a765ff054", "version_major": 2, "version_minor": 0 }, @@ -1059,17 +1059,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1h 57min 29s, sys: 32.5 s, total: 1h 58min 1s\n", - "Wall time: 10min 24s\n" + "CPU times: user 2h 23min 5s, sys: 1min 12s, total: 2h 24min 18s\n", + "Wall time: 13min 45s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1081,7 +1081,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1105,7 +1105,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "145a981ae32e4efc83ff9759a207c331", + "model_id": "8e657a67ee204507a61fb6e6489248bd", "version_major": 2, "version_minor": 0 }, @@ -1120,8 +1120,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 34s, sys: 4.91 s, total: 2min 39s\n", - "Wall time: 18.6 s\n" + "CPU times: user 2min 55s, sys: 6.53 s, total: 3min 1s\n", + "Wall time: 23.3 s\n" ] } ], @@ -1138,7 +1138,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1156,7 +1156,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1184,7 +1184,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -1212,7 +1212,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7259973b494c4e7b9b6700affd991d2c", + "model_id": "a589239c20a943bf9c27af6118cb8f69", "version_major": 2, "version_minor": 0 }, @@ -1224,50 +1224,29 @@ "output_type": "display_data" }, { - "ename": "RuntimeError", - "evalue": "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m:1\u001b[0m\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:306\u001b[0m, in \u001b[0;36mModelBase.fit\u001b[0;34m(self, dataset, *args, **kwargs)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m: T, dataset: Dataset, \u001b[38;5;241m*\u001b[39margs: tp\u001b[38;5;241m.\u001b[39mAny, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: tp\u001b[38;5;241m.\u001b[39mAny) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 294\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;124;03m Fit model.\u001b[39;00m\n\u001b[1;32m 296\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;124;03m self\u001b[39;00m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 306\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_fitted \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:737\u001b[0m, in \u001b[0;36mTransformerModelBase._fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_lightning_model(torch_model)\n\u001b[1;32m 736\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfit_trainer \u001b[38;5;241m=\u001b[39m deepcopy(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer)\n\u001b[0;32m--> 737\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloader\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/optimizer.py:487\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 484\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 485\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 490\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/optimizer.py:91\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 90\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 91\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 93\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/optim/adam.py:202\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 202\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 205\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:315\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule.training_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 313\u001b[0m x, y, w \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m], batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m], batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myw\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msoftmax\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 315\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_full_catalog_logits\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 316\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_calc_softmax_loss(logits, y, w)\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBCE\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:367\u001b[0m, in \u001b[0;36mSessionEncoderLightningModule._get_full_catalog_logits\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_full_catalog_logits\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[0;32m--> 367\u001b[0m item_embs, session_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtorch_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 368\u001b[0m logits \u001b[38;5;241m=\u001b[39m session_embs \u001b[38;5;241m@\u001b[39m item_embs\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m logits\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git_repos/RecTools/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:230\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder.forward\u001b[0;34m(self, sessions)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124;03mForward pass to get item and session embeddings.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;124;03mGet item embeddings.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;124;03m(torch.Tensor, torch.Tensor)\u001b[39;00m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 229\u001b[0m item_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitem_model\u001b[38;5;241m.\u001b[39mget_all_embeddings() \u001b[38;5;66;03m# [n_items + n_item_extra_tokens, n_factors]\u001b[39;00m\n\u001b[0;32m--> 230\u001b[0m session_embs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_sessions\u001b[49m\u001b[43m(\u001b[49m\u001b[43msessions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mitem_embs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# [batch_size, session_max_len, n_factors]\u001b[39;00m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m item_embs, session_embs\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:205\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder.encode_sessions\u001b[0;34m(self, sessions, item_embs)\u001b[0m\n\u001b[1;32m 203\u001b[0m key_padding_mask \u001b[38;5;241m=\u001b[39m sessions \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attn_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# merge masks to prevent nan gradients for torch < 2.5.0\u001b[39;00m\n\u001b[0;32m--> 205\u001b[0m attn_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_merge_masks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseqs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m key_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 208\u001b[0m seqs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransformer_layers(seqs, timeline_mask, attn_mask, key_padding_mask)\n", - "File \u001b[0;32m/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_base.py:163\u001b[0m, in \u001b[0;36mTransformerBasedSessionEncoder._merge_masks\u001b[0;34m(self, attn_mask, key_padding_mask, query)\u001b[0m\n\u001b[1;32m 155\u001b[0m attn_mask_expanded \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_convert_mask_to_float(attn_mask, query) \u001b[38;5;66;03m# [session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m1\u001b[39m, seq_len, seq_len)\n\u001b[1;32m 158\u001b[0m \u001b[38;5;241m.\u001b[39mexpand(batch_size, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 159\u001b[0m ) \u001b[38;5;66;03m# [batch_size, session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 161\u001b[0m merged_mask \u001b[38;5;241m=\u001b[39m attn_mask_expanded \u001b[38;5;241m+\u001b[39m key_padding_mask_expanded\n\u001b[1;32m 162\u001b[0m res \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 163\u001b[0m \u001b[43mmerged_mask\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_heads\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m ) \u001b[38;5;66;03m# [batch_size * n_heads, session_max_len, session_max_len]\u001b[39;00m\n\u001b[1;32m 167\u001b[0m torch\u001b[38;5;241m.\u001b[39mdiagonal(res, dim1\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, dim2\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39mzero_()\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\n", - "\u001b[0;31mRuntimeError\u001b[0m: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead." + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 7min 56s, sys: 15.9 s, total: 8min 12s\n", + "Wall time: 7min 53s\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -1277,32 +1256,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", + "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", " Model `` doesn't support recommendations for cold users,\n", " but some of given users are cold: they are not in the `dataset.user_id_map`\n", " \n", " warnings.warn(explanation)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" + "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bdf51affb5944bc1a4630b17aefbc68f", + "model_id": "b2a472b585d44beb930aca7c172982cc", "version_major": 2, "version_minor": 0 }, @@ -1317,8 +1295,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 25.6 s, sys: 3.04 s, total: 28.6 s\n", - "Wall time: 19.8 s\n" + "CPU times: user 2min 51s, sys: 6.39 s, total: 2min 58s\n", + "Wall time: 23.2 s\n" ] } ], @@ -1335,7 +1313,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1344,7 +1322,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1354,7 +1332,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1403,6 +1381,18 @@ " \n", " \n", " \n", + " softmax_padding_mask\n", + " 0.049449\n", + " 0.083399\n", + " 0.092535\n", + " 3.703682\n", + " 4.303675\n", + " 4.937142\n", + " 0.000960\n", + " 0.000727\n", + " 0.000670\n", + " \n", + " \n", " softmax\n", " 0.048466\n", " 0.081695\n", @@ -1443,20 +1433,22 @@ "" ], "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 MIUF@10 \\\n", - "model \n", - "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 5.159742 \n", - "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 3.942205 \n", - "bce 0.027035 0.051244 0.059080 3.882081 4.384314 4.734298 \n", + " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", + "model \n", + "softmax_padding_mask 0.049449 0.083399 0.092535 3.703682 4.303675 \n", + "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 \n", + "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 \n", + "bce 0.027035 0.051244 0.059080 3.882081 4.384314 \n", "\n", - " Serendipity@1 Serendipity@5 Serendipity@10 \n", - "model \n", - "softmax 0.001117 0.000865 0.000763 \n", - "gBCE 0.000103 0.000118 0.000134 \n", - "bce 0.000104 0.000121 0.000131 " + " MIUF@10 Serendipity@1 Serendipity@5 Serendipity@10 \n", + "model \n", + "softmax_padding_mask 4.937142 0.000960 0.000727 0.000670 \n", + "softmax 5.159742 0.001117 0.000865 0.000763 \n", + "gBCE 3.942205 0.000103 0.000118 0.000134 \n", + "bce 4.734298 0.000104 0.000121 0.000131 " ] }, - "execution_count": 36, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1479,7 +1471,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -1506,7 +1498,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -1534,7 +1526,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "14f1a094052c406ab577da3a5fef8f69", + "model_id": "115180caaade47eb91e2121666f5bb31", "version_major": 2, "version_minor": 0 }, @@ -1556,17 +1548,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 6min 18s, sys: 7.6 s, total: 6min 26s\n", - "Wall time: 6min 11s\n" + "CPU times: user 7min 45s, sys: 17.2 s, total: 8min 2s\n", + "Wall time: 7min 51s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 39, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -1578,7 +1570,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -1602,7 +1594,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a11ad50979364f5985370f3942eb9c8c", + "model_id": "31ef580c7fa3448d9f0e6122149377a6", "version_major": 2, "version_minor": 0 }, @@ -1617,8 +1609,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 35s, sys: 4.93 s, total: 2min 40s\n", - "Wall time: 23.4 s\n" + "CPU times: user 2min 43s, sys: 5.55 s, total: 2min 49s\n", + "Wall time: 21.7 s\n" ] } ], @@ -1635,7 +1627,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -1653,7 +1645,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -1680,7 +1672,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [ { @@ -1694,12 +1686,12 @@ "\n", " | Name | Type | Params | Mode \n", "-----------------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 3.4 M | train\n", + "0 | torch_model | TransformerBasedSessionEncoder | 3.5 M | train\n", "-----------------------------------------------------------------------\n", - "3.4 M Trainable params\n", + "3.5 M Trainable params\n", "0 Non-trainable params\n", - "3.4 M Total params\n", - "13.621 Total estimated model params size (MB)\n", + "3.5 M Total params\n", + "13.903 Total estimated model params size (MB)\n", "39 Modules in train mode\n", "0 Modules in eval mode\n", "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" @@ -1708,7 +1700,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0184c3daa298402da6b3ec16aa38a5dc", + "model_id": "8a9ca68d8e6d40838c608f71ba16b50d", "version_major": 2, "version_minor": 0 }, @@ -1729,10 +1721,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 44, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } @@ -1744,7 +1736,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -1768,7 +1760,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2b68d1c5bf7543c599a45b39190b1fbc", + "model_id": "970a1097f4d947f3b4d841eeed2e7a61", "version_major": 2, "version_minor": 0 }, @@ -1783,8 +1775,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 37s, sys: 8.99 s, total: 2min 46s\n", - "Wall time: 21.8 s\n" + "CPU times: user 2min 43s, sys: 6.03 s, total: 2min 49s\n", + "Wall time: 21 s\n" ] } ], @@ -1801,7 +1793,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -1819,7 +1811,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -1846,7 +1838,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -1865,7 +1857,7 @@ "2.0 M Trainable params\n", "0 Non-trainable params\n", "2.0 M Total params\n", - "7.832 Total estimated model params size (MB)\n", + "8.113 Total estimated model params size (MB)\n", "36 Modules in train mode\n", "0 Modules in eval mode\n", "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" @@ -1874,7 +1866,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "54489a8a6a27455bb0d2b7c9588f0ec8", + "model_id": "6d6507d9857d40439852d93b92ec490f", "version_major": 2, "version_minor": 0 }, @@ -1895,10 +1887,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 53, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -1910,7 +1902,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 51, "metadata": {}, "outputs": [ { @@ -1934,7 +1926,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ec97ecaa85eb4c3ead46f112d1acf090", + "model_id": "aee07340143145f080a9a1abc19dc6fb", "version_major": 2, "version_minor": 0 }, @@ -1949,8 +1941,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2min 31s, sys: 8.96 s, total: 2min 40s\n", - "Wall time: 18.7 s\n" + "CPU times: user 4min 34s, sys: 5.32 s, total: 4min 39s\n", + "Wall time: 23 s\n" ] } ], @@ -1967,7 +1959,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -1978,7 +1970,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 53, "metadata": {}, "outputs": [ { @@ -2014,39 +2006,49 @@ " 'Serendipity@5': 0.00011795153034776448,\n", " 'Serendipity@10': 0.00013442022189753792,\n", " 'model': 'gBCE'},\n", - " {'MAP@1': 0.048147530079556855,\n", - " 'MAP@5': 0.08174752632088274,\n", - " 'MAP@10': 0.09082809872497208,\n", - " 'MIUF@1': 18.824620072061013,\n", - " 'MIUF@5': 18.824620072061013,\n", - " 'MIUF@10': 18.824620072061013,\n", - " 'Serendipity@1': 0.09930837864949052,\n", - " 'Serendipity@5': 0.06018338155520556,\n", - " 'Serendipity@10': 0.04419324756562254,\n", + " {'MAP@1': 0.04944927980383728,\n", + " 'MAP@5': 0.08339905218390668,\n", + " 'MAP@10': 0.09253457436132577,\n", + " 'MIUF@1': 3.7036816495680407,\n", + " 'MIUF@5': 4.303675387547709,\n", + " 'MIUF@10': 4.937142346053472,\n", + " 'Serendipity@1': 0.0009600473125589197,\n", + " 'Serendipity@5': 0.0007267346226828757,\n", + " 'Serendipity@10': 0.0006702334199993685,\n", + " 'model': 'softmax_padding_mask'},\n", + " {'MAP@1': 0.048232041728511546,\n", + " 'MAP@5': 0.08228749067398929,\n", + " 'MAP@10': 0.09149502267291341,\n", + " 'MIUF@1': 3.63326387297099,\n", + " 'MIUF@5': 4.308850581843274,\n", + " 'MIUF@10': 4.966089748691736,\n", + " 'Serendipity@1': 0.0010093499068625526,\n", + " 'Serendipity@5': 0.0008033309987059076,\n", + " 'Serendipity@10': 0.0007285049958172183,\n", " 'model': 'sasrec_ids'},\n", - " {'MAP@1': 0.04766175619468918,\n", - " 'MAP@5': 0.08225834198336934,\n", - " 'MAP@10': 0.09147359309311318,\n", - " 'MIUF@1': 3.9415828408413325,\n", - " 'MIUF@5': 4.572004200795584,\n", - " 'MIUF@10': 5.181503181696694,\n", - " 'Serendipity@1': 0.0012604074213600975,\n", - " 'Serendipity@5': 0.0009317855973533245,\n", - " 'Serendipity@10': 0.0008227161922729454,\n", + " {'MAP@1': 0.046980831099691825,\n", + " 'MAP@5': 0.08064517854487308,\n", + " 'MAP@10': 0.09000793187344863,\n", + " 'MIUF@1': 4.330346818203394,\n", + " 'MIUF@5': 4.974336941917242,\n", + " 'MIUF@10': 5.5104503077061855,\n", + " 'Serendipity@1': 0.0015174010466626042,\n", + " 'Serendipity@5': 0.00110907171190758,\n", + " 'Serendipity@10': 0.0009659832365762532,\n", " 'model': 'sasrec_id_and_cat_features'},\n", - " {'MAP@1': 0.04310550975326949,\n", - " 'MAP@5': 0.07036672111056717,\n", - " 'MAP@10': 0.07822710009694937,\n", - " 'MIUF@1': 4.18466637016627,\n", - " 'MIUF@5': 5.596602216823304,\n", - " 'MIUF@10': 6.135927367440625,\n", - " 'Serendipity@1': 0.0010273693092966185,\n", - " 'Serendipity@5': 0.00088834477237887,\n", - " 'Serendipity@10': 0.0007727660052644417,\n", + " {'MAP@1': 0.04404827524837389,\n", + " 'MAP@5': 0.07166794786711866,\n", + " 'MAP@10': 0.07951454257314965,\n", + " 'MIUF@1': 4.009884724409452,\n", + " 'MIUF@5': 5.5158724169279,\n", + " 'MIUF@10': 6.006320604496849,\n", + " 'Serendipity@1': 0.0009655015112846868,\n", + " 'Serendipity@5': 0.0008125160935052519,\n", + " 'Serendipity@10': 0.000715293506917442,\n", " 'model': 'sasrec_cat_features'}]" ] }, - "execution_count": 56, + "execution_count": 53, "metadata": {}, "output_type": "execute_result" } @@ -2057,7 +2059,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 54, "metadata": {}, "outputs": [ { @@ -2106,28 +2108,28 @@ " \n", " \n", " \n", - " sasrec_id_and_cat_features\n", - " 0.047662\n", - " 0.082258\n", - " 0.091474\n", - " 3.941583\n", - " 4.572004\n", - " 5.181503\n", - " 0.001260\n", - " 0.000932\n", - " 0.000823\n", + " softmax_padding_mask\n", + " 0.049449\n", + " 0.083399\n", + " 0.092535\n", + " 3.703682\n", + " 4.303675\n", + " 4.937142\n", + " 0.000960\n", + " 0.000727\n", + " 0.000670\n", " \n", " \n", " sasrec_ids\n", - " 0.048148\n", - " 0.081748\n", - " 0.090828\n", - " 18.824620\n", - " 18.824620\n", - " 18.824620\n", - " 0.099308\n", - " 0.060183\n", - " 0.044193\n", + " 0.048232\n", + " 0.082287\n", + " 0.091495\n", + " 3.633264\n", + " 4.308851\n", + " 4.966090\n", + " 0.001009\n", + " 0.000803\n", + " 0.000729\n", " \n", " \n", " softmax\n", @@ -2142,6 +2144,18 @@ " 0.000763\n", " \n", " \n", + " sasrec_id_and_cat_features\n", + " 0.046981\n", + " 0.080645\n", + " 0.090008\n", + " 4.330347\n", + " 4.974337\n", + " 5.510450\n", + " 0.001517\n", + " 0.001109\n", + " 0.000966\n", + " \n", + " \n", " gBCE\n", " 0.040848\n", " 0.072356\n", @@ -2155,15 +2169,15 @@ " \n", " \n", " sasrec_cat_features\n", - " 0.043106\n", - " 0.070367\n", - " 0.078227\n", - " 4.184666\n", - " 5.596602\n", - " 6.135927\n", - " 0.001027\n", - " 0.000888\n", - " 0.000773\n", + " 0.044048\n", + " 0.071668\n", + " 0.079515\n", + " 4.009885\n", + " 5.515872\n", + " 6.006321\n", + " 0.000966\n", + " 0.000813\n", + " 0.000715\n", " \n", " \n", " bce\n", @@ -2182,35 +2196,38 @@ "" ], "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 \\\n", - "model \n", - "sasrec_id_and_cat_features 0.047662 0.082258 0.091474 3.941583 \n", - "sasrec_ids 0.048148 0.081748 0.090828 18.824620 \n", - "softmax 0.048466 0.081695 0.090704 3.871426 \n", - "gBCE 0.040848 0.072356 0.080166 2.332397 \n", - "sasrec_cat_features 0.043106 0.070367 0.078227 4.184666 \n", - "bce 0.027035 0.051244 0.059080 3.882081 \n", + " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", + "model \n", + "softmax_padding_mask 0.049449 0.083399 0.092535 3.703682 4.303675 \n", + "sasrec_ids 0.048232 0.082287 0.091495 3.633264 4.308851 \n", + "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 \n", + "sasrec_id_and_cat_features 0.046981 0.080645 0.090008 4.330347 4.974337 \n", + "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 \n", + "sasrec_cat_features 0.044048 0.071668 0.079515 4.009885 5.515872 \n", + "bce 0.027035 0.051244 0.059080 3.882081 4.384314 \n", "\n", - " MIUF@5 MIUF@10 Serendipity@1 \\\n", - "model \n", - "sasrec_id_and_cat_features 4.572004 5.181503 0.001260 \n", - "sasrec_ids 18.824620 18.824620 0.099308 \n", - "softmax 4.573069 5.159742 0.001117 \n", - "gBCE 3.093763 3.942205 0.000103 \n", - "sasrec_cat_features 5.596602 6.135927 0.001027 \n", - "bce 4.384314 4.734298 0.000104 \n", + " MIUF@10 Serendipity@1 Serendipity@5 \\\n", + "model \n", + "softmax_padding_mask 4.937142 0.000960 0.000727 \n", + "sasrec_ids 4.966090 0.001009 0.000803 \n", + "softmax 5.159742 0.001117 0.000865 \n", + "sasrec_id_and_cat_features 5.510450 0.001517 0.001109 \n", + "gBCE 3.942205 0.000103 0.000118 \n", + "sasrec_cat_features 6.006321 0.000966 0.000813 \n", + "bce 4.734298 0.000104 0.000121 \n", "\n", - " Serendipity@5 Serendipity@10 \n", - "model \n", - "sasrec_id_and_cat_features 0.000932 0.000823 \n", - "sasrec_ids 0.060183 0.044193 \n", - "softmax 0.000865 0.000763 \n", - "gBCE 0.000118 0.000134 \n", - "sasrec_cat_features 0.000888 0.000773 \n", - "bce 0.000121 0.000131 " + " Serendipity@10 \n", + "model \n", + "softmax_padding_mask 0.000670 \n", + "sasrec_ids 0.000729 \n", + "softmax 0.000763 \n", + "sasrec_id_and_cat_features 0.000966 \n", + "gBCE 0.000134 \n", + "sasrec_cat_features 0.000715 \n", + "bce 0.000131 " ] }, - "execution_count": 57, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -2233,7 +2250,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -2242,15 +2259,15 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3.14 s, sys: 4.21 s, total: 7.35 s\n", - "Wall time: 1.15 s\n" + "CPU times: user 10.3 s, sys: 600 ms, total: 10.9 s\n", + "Wall time: 1.68 s\n" ] } ], @@ -2267,7 +2284,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 57, "metadata": {}, "outputs": [ { @@ -2316,196 +2333,196 @@ " 2\n", " 13865\n", " 6409\n", - " 0.628877\n", + " 0.559578\n", " 3\n", " \n", " \n", " 3\n", " 13865\n", " 142\n", - " 0.559630\n", + " 0.532108\n", " 4\n", " \n", " \n", " 4\n", " 13865\n", - " 2657\n", - " 0.514484\n", + " 4457\n", + " 0.531360\n", " 5\n", " \n", " \n", " 5\n", " 13865\n", - " 4457\n", - " 0.503537\n", + " 6809\n", + " 0.531058\n", " 6\n", " \n", " \n", " 6\n", " 13865\n", - " 15297\n", - " 0.500209\n", + " 10440\n", + " 0.522596\n", " 7\n", " \n", " \n", " 7\n", " 13865\n", - " 6809\n", - " 0.487185\n", + " 2657\n", + " 0.512829\n", " 8\n", " \n", " \n", " 8\n", " 13865\n", " 10772\n", - " 0.485932\n", + " 0.511374\n", " 9\n", " \n", " \n", " 9\n", " 13865\n", - " 10440\n", - " 0.473830\n", + " 15297\n", + " 0.477553\n", " 10\n", " \n", " \n", " 10\n", " 4457\n", " 14741\n", - " 0.725818\n", + " 0.743163\n", " 1\n", " \n", " \n", " 11\n", " 4457\n", " 12995\n", - " 0.658568\n", + " 0.644302\n", " 2\n", " \n", " \n", " 12\n", " 4457\n", - " 142\n", - " 0.646901\n", + " 6455\n", + " 0.629546\n", " 3\n", " \n", " \n", " 13\n", " 4457\n", - " 3935\n", - " 0.641437\n", + " 142\n", + " 0.624334\n", " 4\n", " \n", " \n", " 14\n", " 4457\n", - " 10772\n", - " 0.621477\n", + " 1287\n", + " 0.622713\n", " 5\n", " \n", " \n", " 15\n", " 4457\n", - " 1287\n", - " 0.611060\n", + " 274\n", + " 0.609102\n", " 6\n", " \n", " \n", " 16\n", " 4457\n", - " 3509\n", - " 0.610761\n", + " 10772\n", + " 0.604250\n", " 7\n", " \n", " \n", " 17\n", " 4457\n", - " 15464\n", - " 0.592197\n", + " 3935\n", + " 0.601068\n", " 8\n", " \n", " \n", " 18\n", " 4457\n", - " 274\n", - " 0.587578\n", + " 3509\n", + " 0.585827\n", " 9\n", " \n", " \n", " 19\n", " 4457\n", - " 6455\n", - " 0.583398\n", + " 9342\n", + " 0.575651\n", " 10\n", " \n", " \n", " 20\n", " 15297\n", " 142\n", - " 0.614521\n", + " 0.612435\n", " 1\n", " \n", " \n", " 21\n", " 15297\n", " 2657\n", - " 0.567724\n", + " 0.565547\n", " 2\n", " \n", " \n", " 22\n", " 15297\n", " 6809\n", - " 0.554094\n", + " 0.565121\n", " 3\n", " \n", " \n", " 23\n", " 15297\n", - " 10772\n", - " 0.539985\n", + " 10440\n", + " 0.562896\n", " 4\n", " \n", " \n", " 24\n", " 15297\n", - " 10440\n", - " 0.529680\n", + " 10772\n", + " 0.529499\n", " 5\n", " \n", " \n", " 25\n", " 15297\n", - " 14337\n", - " 0.514991\n", + " 3935\n", + " 0.520520\n", " 6\n", " \n", " \n", " 26\n", " 15297\n", - " 1844\n", - " 0.514991\n", + " 4457\n", + " 0.515862\n", " 7\n", " \n", " \n", " 27\n", " 15297\n", - " 4457\n", - " 0.507450\n", + " 14337\n", + " 0.506332\n", " 8\n", " \n", " \n", " 28\n", " 15297\n", - " 13865\n", - " 0.500209\n", + " 1844\n", + " 0.506332\n", " 9\n", " \n", " \n", " 29\n", " 15297\n", - " 7107\n", - " 0.500209\n", + " 8636\n", + " 0.503729\n", " 10\n", " \n", " \n", @@ -2516,37 +2533,37 @@ " target_item_id item_id score rank\n", "0 13865 11863 1.000000 1\n", "1 13865 7107 1.000000 2\n", - "2 13865 6409 0.628877 3\n", - "3 13865 142 0.559630 4\n", - "4 13865 2657 0.514484 5\n", - "5 13865 4457 0.503537 6\n", - "6 13865 15297 0.500209 7\n", - "7 13865 6809 0.487185 8\n", - "8 13865 10772 0.485932 9\n", - "9 13865 10440 0.473830 10\n", - "10 4457 14741 0.725818 1\n", - "11 4457 12995 0.658568 2\n", - "12 4457 142 0.646901 3\n", - "13 4457 3935 0.641437 4\n", - "14 4457 10772 0.621477 5\n", - "15 4457 1287 0.611060 6\n", - "16 4457 3509 0.610761 7\n", - "17 4457 15464 0.592197 8\n", - "18 4457 274 0.587578 9\n", - "19 4457 6455 0.583398 10\n", - "20 15297 142 0.614521 1\n", - "21 15297 2657 0.567724 2\n", - "22 15297 6809 0.554094 3\n", - "23 15297 10772 0.539985 4\n", - "24 15297 10440 0.529680 5\n", - "25 15297 14337 0.514991 6\n", - "26 15297 1844 0.514991 7\n", - "27 15297 4457 0.507450 8\n", - "28 15297 13865 0.500209 9\n", - "29 15297 7107 0.500209 10" + "2 13865 6409 0.559578 3\n", + "3 13865 142 0.532108 4\n", + "4 13865 4457 0.531360 5\n", + "5 13865 6809 0.531058 6\n", + "6 13865 10440 0.522596 7\n", + "7 13865 2657 0.512829 8\n", + "8 13865 10772 0.511374 9\n", + "9 13865 15297 0.477553 10\n", + "10 4457 14741 0.743163 1\n", + "11 4457 12995 0.644302 2\n", + "12 4457 6455 0.629546 3\n", + "13 4457 142 0.624334 4\n", + "14 4457 1287 0.622713 5\n", + "15 4457 274 0.609102 6\n", + "16 4457 10772 0.604250 7\n", + "17 4457 3935 0.601068 8\n", + "18 4457 3509 0.585827 9\n", + "19 4457 9342 0.575651 10\n", + "20 15297 142 0.612435 1\n", + "21 15297 2657 0.565547 2\n", + "22 15297 6809 0.565121 3\n", + "23 15297 10440 0.562896 4\n", + "24 15297 10772 0.529499 5\n", + "25 15297 3935 0.520520 6\n", + "26 15297 4457 0.515862 7\n", + "27 15297 14337 0.506332 8\n", + "28 15297 1844 0.506332 9\n", + "29 15297 8636 0.503729 10" ] }, - "execution_count": 60, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } @@ -2557,7 +2574,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 58, "metadata": {}, "outputs": [ { @@ -2609,7 +2626,7 @@ " 2\n", " 13865\n", " 6409\n", - " 0.628877\n", + " 0.559578\n", " 3\n", " Особо опасен\n", " \n", @@ -2617,63 +2634,63 @@ " 3\n", " 13865\n", " 142\n", - " 0.559630\n", + " 0.532108\n", " 4\n", " Маша\n", " \n", " \n", " 4\n", " 13865\n", - " 2657\n", - " 0.514484\n", + " 4457\n", + " 0.531360\n", " 5\n", - " Подслушано\n", + " 2067: Петля времени\n", " \n", " \n", " 5\n", " 13865\n", - " 4457\n", - " 0.503537\n", + " 6809\n", + " 0.531058\n", " 6\n", - " 2067: Петля времени\n", + " Дуров\n", " \n", " \n", " 6\n", " 13865\n", - " 15297\n", - " 0.500209\n", + " 10440\n", + " 0.522596\n", " 7\n", - " Клиника счастья\n", + " Хрустальный\n", " \n", " \n", " 7\n", " 13865\n", - " 6809\n", - " 0.487185\n", + " 2657\n", + " 0.512829\n", " 8\n", - " Дуров\n", + " Подслушано\n", " \n", " \n", " 8\n", " 13865\n", " 10772\n", - " 0.485932\n", + " 0.511374\n", " 9\n", " Зелёная книга\n", " \n", " \n", " 9\n", " 13865\n", - " 10440\n", - " 0.473830\n", + " 15297\n", + " 0.477553\n", " 10\n", - " Хрустальный\n", + " Клиника счастья\n", " \n", " \n", " 10\n", " 4457\n", " 14741\n", - " 0.725818\n", + " 0.743163\n", " 1\n", " Цвет из иных миров\n", " \n", @@ -2681,79 +2698,79 @@ " 11\n", " 4457\n", " 12995\n", - " 0.658568\n", + " 0.644302\n", " 2\n", " Восемь сотен\n", " \n", " \n", " 12\n", " 4457\n", - " 142\n", - " 0.646901\n", + " 6455\n", + " 0.629546\n", " 3\n", - " Маша\n", + " Альфа\n", " \n", " \n", " 13\n", " 4457\n", - " 3935\n", - " 0.641437\n", + " 142\n", + " 0.624334\n", " 4\n", - " Бывшая с того света\n", + " Маша\n", " \n", " \n", " 14\n", " 4457\n", - " 10772\n", - " 0.621477\n", + " 1287\n", + " 0.622713\n", " 5\n", - " Зелёная книга\n", + " Терминатор: Тёмные судьбы\n", " \n", " \n", " 15\n", " 4457\n", - " 1287\n", - " 0.611060\n", + " 274\n", + " 0.609102\n", " 6\n", - " Терминатор: Тёмные судьбы\n", + " Логан\n", " \n", " \n", " 16\n", " 4457\n", - " 3509\n", - " 0.610761\n", + " 10772\n", + " 0.604250\n", " 7\n", - " Комната желаний\n", + " Зелёная книга\n", " \n", " \n", " 17\n", " 4457\n", - " 15464\n", - " 0.592197\n", + " 3935\n", + " 0.601068\n", " 8\n", - " Апгрейд\n", + " Бывшая с того света\n", " \n", " \n", " 18\n", " 4457\n", - " 274\n", - " 0.587578\n", + " 3509\n", + " 0.585827\n", " 9\n", - " Логан\n", + " Комната желаний\n", " \n", " \n", " 19\n", " 4457\n", - " 6455\n", - " 0.583398\n", + " 9342\n", + " 0.575651\n", " 10\n", - " Альфа\n", + " Дэдпул\n", " \n", " \n", " 20\n", " 15297\n", " 142\n", - " 0.614521\n", + " 0.612435\n", " 1\n", " Маша\n", " \n", @@ -2761,7 +2778,7 @@ " 21\n", " 15297\n", " 2657\n", - " 0.567724\n", + " 0.565547\n", " 2\n", " Подслушано\n", " \n", @@ -2769,65 +2786,65 @@ " 22\n", " 15297\n", " 6809\n", - " 0.554094\n", + " 0.565121\n", " 3\n", " Дуров\n", " \n", " \n", " 23\n", " 15297\n", - " 10772\n", - " 0.539985\n", + " 10440\n", + " 0.562896\n", " 4\n", - " Зелёная книга\n", + " Хрустальный\n", " \n", " \n", " 24\n", " 15297\n", - " 10440\n", - " 0.529680\n", + " 10772\n", + " 0.529499\n", " 5\n", - " Хрустальный\n", + " Зелёная книга\n", " \n", " \n", " 25\n", " 15297\n", - " 14337\n", - " 0.514991\n", + " 3935\n", + " 0.520520\n", " 6\n", - " [4К] Аферистка\n", + " Бывшая с того света\n", " \n", " \n", " 26\n", " 15297\n", - " 1844\n", - " 0.514991\n", + " 4457\n", + " 0.515862\n", " 7\n", - " Аферистка\n", + " 2067: Петля времени\n", " \n", " \n", " 27\n", " 15297\n", - " 4457\n", - " 0.507450\n", + " 14337\n", + " 0.506332\n", " 8\n", - " 2067: Петля времени\n", + " [4К] Аферистка\n", " \n", " \n", " 28\n", " 15297\n", - " 13865\n", - " 0.500209\n", + " 1844\n", + " 0.506332\n", " 9\n", - " Девятаев\n", + " Аферистка\n", " \n", " \n", " 29\n", " 15297\n", - " 7107\n", - " 0.500209\n", + " 8636\n", + " 0.503729\n", " 10\n", - " Девятаев\n", + " Белый снег\n", " \n", " \n", "\n", @@ -2837,37 +2854,37 @@ " target_item_id item_id score rank title\n", "0 13865 11863 1.000000 1 Девятаев - сериал\n", "1 13865 7107 1.000000 2 Девятаев\n", - "2 13865 6409 0.628877 3 Особо опасен\n", - "3 13865 142 0.559630 4 Маша\n", - "4 13865 2657 0.514484 5 Подслушано\n", - "5 13865 4457 0.503537 6 2067: Петля времени\n", - "6 13865 15297 0.500209 7 Клиника счастья\n", - "7 13865 6809 0.487185 8 Дуров\n", - "8 13865 10772 0.485932 9 Зелёная книга\n", - "9 13865 10440 0.473830 10 Хрустальный\n", - "10 4457 14741 0.725818 1 Цвет из иных миров\n", - "11 4457 12995 0.658568 2 Восемь сотен\n", - "12 4457 142 0.646901 3 Маша\n", - "13 4457 3935 0.641437 4 Бывшая с того света\n", - "14 4457 10772 0.621477 5 Зелёная книга\n", - "15 4457 1287 0.611060 6 Терминатор: Тёмные судьбы\n", - "16 4457 3509 0.610761 7 Комната желаний\n", - "17 4457 15464 0.592197 8 Апгрейд\n", - "18 4457 274 0.587578 9 Логан\n", - "19 4457 6455 0.583398 10 Альфа\n", - "20 15297 142 0.614521 1 Маша\n", - "21 15297 2657 0.567724 2 Подслушано\n", - "22 15297 6809 0.554094 3 Дуров\n", - "23 15297 10772 0.539985 4 Зелёная книга\n", - "24 15297 10440 0.529680 5 Хрустальный\n", - "25 15297 14337 0.514991 6 [4К] Аферистка\n", - "26 15297 1844 0.514991 7 Аферистка\n", - "27 15297 4457 0.507450 8 2067: Петля времени\n", - "28 15297 13865 0.500209 9 Девятаев\n", - "29 15297 7107 0.500209 10 Девятаев" + "2 13865 6409 0.559578 3 Особо опасен\n", + "3 13865 142 0.532108 4 Маша\n", + "4 13865 4457 0.531360 5 2067: Петля времени\n", + "5 13865 6809 0.531058 6 Дуров\n", + "6 13865 10440 0.522596 7 Хрустальный\n", + "7 13865 2657 0.512829 8 Подслушано\n", + "8 13865 10772 0.511374 9 Зелёная книга\n", + "9 13865 15297 0.477553 10 Клиника счастья\n", + "10 4457 14741 0.743163 1 Цвет из иных миров\n", + "11 4457 12995 0.644302 2 Восемь сотен\n", + "12 4457 6455 0.629546 3 Альфа\n", + "13 4457 142 0.624334 4 Маша\n", + "14 4457 1287 0.622713 5 Терминатор: Тёмные судьбы\n", + "15 4457 274 0.609102 6 Логан\n", + "16 4457 10772 0.604250 7 Зелёная книга\n", + "17 4457 3935 0.601068 8 Бывшая с того света\n", + "18 4457 3509 0.585827 9 Комната желаний\n", + "19 4457 9342 0.575651 10 Дэдпул\n", + "20 15297 142 0.612435 1 Маша\n", + "21 15297 2657 0.565547 2 Подслушано\n", + "22 15297 6809 0.565121 3 Дуров\n", + "23 15297 10440 0.562896 4 Хрустальный\n", + "24 15297 10772 0.529499 5 Зелёная книга\n", + "25 15297 3935 0.520520 6 Бывшая с того света\n", + "26 15297 4457 0.515862 7 2067: Петля времени\n", + "27 15297 14337 0.506332 8 [4К] Аферистка\n", + "28 15297 1844 0.506332 9 Аферистка\n", + "29 15297 8636 0.503729 10 Белый снег" ] }, - "execution_count": 61, + "execution_count": 58, "metadata": {}, "output_type": "execute_result" } diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index b9c358b6..f3e2b381 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -62,22 +62,22 @@ class CatFeaturesItemNet(ItemNetBase): def __init__( self, emb_bag_inputs: torch.Tensor, - len_indexes: torch.Tensor, + input_lengths: torch.Tensor, offsets: torch.Tensor, - n_cat_features: int, + n_cat_feature_values: int, n_factors: int, dropout_rate: float, ): super().__init__() - self.n_cat_features = n_cat_features - self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_features, embedding_dim=n_factors, padding_idx=0) + self.n_cat_feature_values = n_cat_feature_values + self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors) self.drop_layer = nn.Dropout(dropout_rate) self.register_buffer("offsets", offsets) - self.register_buffer("emb_bag_indexes", torch.arange(len(emb_bag_inputs))) + self.register_buffer("length_range", torch.arange(input_lengths.max().item())) self.register_buffer("emb_bag_inputs", emb_bag_inputs) - self.register_buffer("len_indexes", len_indexes) + self.register_buffer("input_lengths", input_lengths) def forward(self, items: torch.Tensor) -> torch.Tensor: """ @@ -100,19 +100,14 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: """Get categorical item features and offsets for `items`.""" - item_indexes = self.offsets[items].unsqueeze(-1) + self.emb_bag_indexes - length_mask = self.emb_bag_indexes < self.len_indexes[items].unsqueeze(-1) - item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]].squeeze(-1) + item_indexes = self.offsets[items].unsqueeze(-1) + self.length_range + length_mask = self.length_range < self.input_lengths[items].unsqueeze(-1) + item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]] item_offsets = torch.cat( - (torch.tensor([0], device=self.device), torch.cumsum(self.len_indexes[items], dim=0)[:-1]) + (torch.tensor([0], device=self.device), torch.cumsum(self.input_lengths[items], dim=0)[:-1]) ) return item_emb_bag_inputs, item_offsets - @property - def feature_catalog(self) -> torch.Tensor: - """Return tensor with elements in range [0, n_cat_features).""" - return torch.arange(0, self.n_cat_features) - @classmethod def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> tp.Optional[tpe.Self]: """ @@ -153,14 +148,14 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> emb_bag_inputs = torch.tensor(item_cat_features.values.indices, dtype=torch.long) offsets = torch.tensor(item_cat_features.values.indptr, dtype=torch.long) - len_indexes = torch.diff(offsets, dim=0) - n_cat_features = len(item_cat_features.names) + input_lengths = torch.diff(offsets, dim=0) + n_cat_feature_values = len(item_cat_features.names) return cls( emb_bag_inputs=emb_bag_inputs, offsets=offsets[:-1], - len_indexes=len_indexes, - n_cat_features=n_cat_features, + input_lengths=input_lengths, + n_cat_feature_values=n_cat_feature_values, n_factors=n_factors, dropout_rate=dropout_rate, ) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index f3d4c14d..738fdb81 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -162,7 +162,7 @@ def _merge_masks( res = ( merged_mask.view(batch_size, 1, seq_len, seq_len) .expand(-1, self.n_heads, -1, -1) - .view(-1, seq_len, seq_len) + .reshape(-1, seq_len, seq_len) ) # [batch_size * n_heads, session_max_len, session_max_len] torch.diagonal(res, dim1=1, dim2=2).zero_() return res diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index 5984a5da..ad8eb88c 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -22,9 +22,7 @@ from rectools.columns import Columns from rectools.dataset import Dataset -from rectools.dataset.features import SparseFeatures from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase, ItemNetConstructor -from tests.testing_utils import assert_feature_set_equal from ..data import DATASET, INTERACTIONS @@ -69,8 +67,6 @@ def _seed_everything(self) -> None: def dataset_item_features(self) -> Dataset: item_features = pd.DataFrame( [ - [11, "f1", "f1val1"], - [11, "f2", "f2val1"], [12, "f1", "f1val1"], [12, "f2", "f2val2"], [13, "f1", "f1val1"], @@ -83,7 +79,6 @@ def dataset_item_features(self) -> Dataset: [17, "f2", "f2val3"], [16, "f1", "f1val2"], [16, "f2", "f2val3"], - [11, "f3", 0], [12, "f3", 1], [13, "f3", 2], [14, "f3", 3], @@ -100,34 +95,19 @@ def dataset_item_features(self) -> Dataset: ) return ds - def test_feature_catalog(self, dataset_item_features: Dataset) -> None: - cat_item_embeddings = CatFeaturesItemNet.from_dataset(dataset_item_features, n_factors=5, dropout_rate=0.5) - assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - expected_feature_catalog = torch.arange(0, cat_item_embeddings.n_cat_features) - assert torch.equal(cat_item_embeddings.feature_catalog, expected_feature_catalog) - - def test_get_dense_item_features(self, dataset_item_features: Dataset) -> None: + def test_get_item_inputs_offsets(self, dataset_item_features: Dataset) -> None: items = torch.from_numpy( dataset_item_features.item_id_map.convert_to_internal(INTERACTIONS[Columns.Item].unique()) - ) + )[:-1] cat_item_embeddings = CatFeaturesItemNet.from_dataset(dataset_item_features, n_factors=5, dropout_rate=0.5) assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - actual_feature_dense = cat_item_embeddings.get_dense_item_features(items) - expected_feature_dense = torch.tensor( - [ - [1, 0, 1, 0, 0], - [1, 0, 0, 1, 0], - [0, 1, 1, 0, 0], - [1, 0, 0, 0, 1], - [0, 1, 0, 1, 0], - [0, 1, 0, 0, 1], - ], - dtype=torch.float, - ) - - assert torch.equal(actual_feature_dense, expected_feature_dense) + actual_item_emb_bag_inputs, actual_item_offsets = cat_item_embeddings.get_item_inputs_offsets(items) + expected_item_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2]) + expected_item_offsets = torch.tensor([0, 0, 2, 4, 6]) + assert torch.equal(actual_item_emb_bag_inputs, expected_item_emb_bag_inputs) + assert torch.equal(actual_item_offsets, expected_item_offsets) @pytest.mark.parametrize("n_factors", (10, 100)) def test_create_from_dataset(self, n_factors: int, dataset_item_features: Dataset) -> None: @@ -137,20 +117,24 @@ def test_create_from_dataset(self, n_factors: int, dataset_item_features: Datase assert isinstance(cat_item_embeddings, CatFeaturesItemNet) - actual_item_features = cat_item_embeddings.item_features - actual_n_items = cat_item_embeddings.n_items - actual_n_cat_features = cat_item_embeddings.n_cat_features - actual_embedding_dim = cat_item_embeddings.category_embeddings.embedding_dim - - expected_item_features = dataset_item_features.item_features + actual_offsets = cat_item_embeddings.offsets + actual_n_cat_feature_values = cat_item_embeddings.n_cat_feature_values + actual_embedding_dim = cat_item_embeddings.embedding_bag.embedding_dim + actual_length_range = cat_item_embeddings.length_range + actual_emb_bag_inputs = cat_item_embeddings.emb_bag_inputs + actual_input_lengths = cat_item_embeddings.input_lengths - assert isinstance(expected_item_features, SparseFeatures) - expected_cat_item_features = expected_item_features.get_cat_features() + expected_offsets = torch.tensor([0, 0, 2, 4, 6, 8, 10]) + expected_length_range = torch.tensor([0, 1]) + expected_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2, 1, 3, 1, 3]) + expected_input_lengths = torch.tensor([0, 2, 2, 2, 2, 2, 2]) - assert_feature_set_equal(actual_item_features, expected_cat_item_features) - assert actual_n_items == dataset_item_features.item_id_map.size - assert actual_n_cat_features == len(expected_cat_item_features.names) + assert actual_n_cat_feature_values == 5 assert actual_embedding_dim == n_factors + assert torch.equal(actual_offsets, expected_offsets) + assert torch.equal(actual_length_range, expected_length_range) + assert torch.equal(actual_emb_bag_inputs, expected_emb_bag_inputs) + assert torch.equal(actual_input_lengths, expected_input_lengths) @pytest.mark.parametrize( "n_items,n_factors", From 41b645fb96d53f01364d84c899b5e6ec324fe873 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Tue, 4 Feb 2025 18:49:12 +0300 Subject: [PATCH 3/5] added from_dataset_schema to CatFeaturesItemNet --- rectools/models/nn/item_net.py | 50 ++++++++++++++++++++++-- tests/models/nn/test_item_net.py | 3 -- tests/models/nn/test_transformer_base.py | 43 +++++++++++++++++--- 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index dfa3d761..e3be82f2 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -36,7 +36,9 @@ def from_dataset(cls, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> tp.O raise NotImplementedError() @classmethod - def from_dataset_schema(cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any) -> tpe.Self: + def from_dataset_schema( + cls, dataset_schema: DatasetSchema, *args: tp.Any, **kwargs: tp.Any + ) -> tp.Optional[tpe.Self]: """Construct ItemNet from Dataset schema.""" raise NotImplementedError() @@ -80,7 +82,6 @@ def __init__( self.drop_layer = nn.Dropout(dropout_rate) self.register_buffer("offsets", offsets) - self.register_buffer("length_range", torch.arange(input_lengths.max().item())) self.register_buffer("emb_bag_inputs", emb_bag_inputs) self.register_buffer("input_lengths", input_lengths) @@ -105,8 +106,9 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: def get_item_inputs_offsets(self, items: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: """Get categorical item features and offsets for `items`.""" - item_indexes = self.offsets[items].unsqueeze(-1) + self.length_range - length_mask = self.length_range < self.input_lengths[items].unsqueeze(-1) + length_range = torch.arange(self.input_lengths.max().item(), device=self.device) + item_indexes = self.offsets[items].unsqueeze(-1) + length_range + length_mask = length_range < self.input_lengths[items].unsqueeze(-1) item_emb_bag_inputs = self.emb_bag_inputs[item_indexes[length_mask]] item_offsets = torch.cat( (torch.tensor([0], device=self.device), torch.cumsum(self.input_lengths[items], dim=0)[:-1]) @@ -165,6 +167,46 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> dropout_rate=dropout_rate, ) + @classmethod + def from_dataset_schema( + cls, dataset_schema: DatasetSchema, n_factors: int, dropout_rate: float + ) -> tp.Optional[tpe.Self]: + """Construct CatFeaturesItemNet from Dataset schema.""" + if dataset_schema.items.features is None: + explanation = """Ignoring `CatFeaturesItemNet` block because dataset doesn't contain item features.""" + warnings.warn(explanation) + return None + + if dataset_schema.items.features.kind == "dense": + explanation = """ + Ignoring `CatFeaturesItemNet` block because + dataset item features are dense and unable to contain categorical features. + """ + warnings.warn(explanation) + return None + + if len(dataset_schema.items.features.cat_feature_indices) == 0: + explanation = """ + Ignoring `CatFeaturesItemNet` block because dataset item features do not contain categorical features. + """ + warnings.warn(explanation) + return None + + emb_bag_inputs = torch.randint( + high=dataset_schema.items.n_hot, size=(dataset_schema.items.features.cat_n_stored_values,) + ) + offsets = torch.randint(high=dataset_schema.items.n_hot, size=(dataset_schema.items.n_hot,)) + input_lengths = torch.randint(high=dataset_schema.items.n_hot, size=(dataset_schema.items.n_hot,)) + n_cat_feature_values = len(dataset_schema.items.features.cat_feature_indices) + return cls( + emb_bag_inputs=emb_bag_inputs, + offsets=offsets, + input_lengths=input_lengths, + n_cat_feature_values=n_cat_feature_values, + n_factors=n_factors, + dropout_rate=dropout_rate, + ) + class IdEmbeddingsItemNet(ItemNetBase): """ diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index d87281db..cb47df29 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -120,19 +120,16 @@ def test_create_from_dataset(self, n_factors: int, dataset_item_features: Datase actual_offsets = cat_item_embeddings.offsets actual_n_cat_feature_values = cat_item_embeddings.n_cat_feature_values actual_embedding_dim = cat_item_embeddings.embedding_bag.embedding_dim - actual_length_range = cat_item_embeddings.length_range actual_emb_bag_inputs = cat_item_embeddings.emb_bag_inputs actual_input_lengths = cat_item_embeddings.input_lengths expected_offsets = torch.tensor([0, 0, 2, 4, 6, 8, 10]) - expected_length_range = torch.tensor([0, 1]) expected_emb_bag_inputs = torch.tensor([0, 2, 1, 4, 0, 3, 1, 2, 1, 3, 1, 3]) expected_input_lengths = torch.tensor([0, 2, 2, 2, 2, 2, 2]) assert actual_n_cat_feature_values == 5 assert actual_embedding_dim == n_factors assert torch.equal(actual_offsets, expected_offsets) - assert torch.equal(actual_length_range, expected_length_range) assert torch.equal(actual_emb_bag_inputs, expected_emb_bag_inputs) assert torch.equal(actual_input_lengths, expected_input_lengths) diff --git a/tests/models/nn/test_transformer_base.py b/tests/models/nn/test_transformer_base.py index df6f2c25..c4e868e5 100644 --- a/tests/models/nn/test_transformer_base.py +++ b/tests/models/nn/test_transformer_base.py @@ -26,10 +26,11 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import BERT4RecModel, SASRecModel, load_model -from rectools.models.nn.item_net import IdEmbeddingsItemNet +from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet from rectools.models.nn.transformer_base import TransformerModelBase from tests.models.utils import assert_save_load_do_not_change_model +from ..data import INTERACTIONS from .utils import custom_trainer, leave_one_out_mask @@ -68,6 +69,38 @@ def interactions_df(self) -> pd.DataFrame: def dataset(self, interactions_df: pd.DataFrame) -> Dataset: return Dataset.construct(interactions_df) + @pytest.fixture + def dataset_item_features(self) -> Dataset: + item_features = pd.DataFrame( + [ + [12, "f1", "f1val1"], + [12, "f2", "f2val2"], + [13, "f1", "f1val1"], + [13, "f2", "f2val3"], + [14, "f1", "f1val2"], + [14, "f2", "f2val1"], + [15, "f1", "f1val2"], + [15, "f2", "f2val2"], + [17, "f1", "f1val2"], + [17, "f2", "f2val3"], + [16, "f1", "f1val2"], + [16, "f2", "f2val3"], + [12, "f3", 1], + [13, "f3", 2], + [14, "f3", 3], + [15, "f3", 4], + [17, "f3", 5], + [16, "f3", 6], + ], + columns=["id", "feature", "value"], + ) + ds = Dataset.construct( + INTERACTIONS, + item_features_df=item_features, + cat_item_features=["f1", "f2"], + ) + return ds + @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("default_trainer", (True, False)) def test_save_load_for_unfitted_model( @@ -123,12 +156,12 @@ def test_load_from_checkpoint( self, model_cls: tp.Type[TransformerModelBase], tmp_path: str, - dataset: Dataset, + dataset_item_features: Dataset, ) -> None: model = model_cls.from_config( { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } ) model._trainer = Trainer( # pylint: disable=protected-access @@ -140,7 +173,7 @@ def test_load_from_checkpoint( devices=1, callbacks=ModelCheckpoint(filename="last_epoch"), ) - model.fit(dataset) + model.fit(dataset_item_features) assert model.fit_trainer is not None if model.fit_trainer.log_dir is None: @@ -150,7 +183,7 @@ def test_load_from_checkpoint( recovered_model = model_cls.load_from_checkpoint(ckpt_path) assert isinstance(recovered_model, model_cls) - self._assert_same_reco(model, recovered_model, dataset) + self._assert_same_reco(model, recovered_model, dataset_item_features) @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("verbose", (1, 0)) From 1c3959a968a8d995b00aafd7e3f336707017d731 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Wed, 5 Feb 2025 12:42:33 +0300 Subject: [PATCH 4/5] fixed TODO in test_transformer_base.py --- rectools/models/nn/transformer_base.py | 1 - tests/models/nn/test_item_net.py | 3 --- tests/models/nn/test_transformer_base.py | 14 +++++++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rectools/models/nn/transformer_base.py b/rectools/models/nn/transformer_base.py index a07179b5..45bafbaa 100644 --- a/rectools/models/nn/transformer_base.py +++ b/rectools/models/nn/transformer_base.py @@ -48,7 +48,6 @@ TransformerLayersBase, ) - # #### -------------- Transformer Model Config -------------- #### # diff --git a/tests/models/nn/test_item_net.py b/tests/models/nn/test_item_net.py index 49a5de63..69f641ea 100644 --- a/tests/models/nn/test_item_net.py +++ b/tests/models/nn/test_item_net.py @@ -22,8 +22,6 @@ from rectools.columns import Columns from rectools.dataset import Dataset - -from rectools.dataset.features import SparseFeatures from rectools.models.nn.item_net import ( CatFeaturesItemNet, IdEmbeddingsItemNet, @@ -31,7 +29,6 @@ ItemNetConstructorBase, SumOfEmbeddingsConstructor, ) -from tests.testing_utils import assert_feature_set_equal from ..data import DATASET, INTERACTIONS diff --git a/tests/models/nn/test_transformer_base.py b/tests/models/nn/test_transformer_base.py index c4e868e5..1344e7a7 100644 --- a/tests/models/nn/test_transformer_base.py +++ b/tests/models/nn/test_transformer_base.py @@ -108,7 +108,7 @@ def test_save_load_for_unfitted_model( ) -> None: config = { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } if not default_trainer: config["get_trainer_func"] = custom_trainer @@ -139,17 +139,21 @@ def _assert_same_reco(self, model_1: TransformerModelBase, model_2: TransformerM @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) @pytest.mark.parametrize("default_trainer", (True, False)) def test_save_load_for_fitted_model( - self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset, default_trainer: bool, trainer: Trainer + self, + model_cls: tp.Type[TransformerModelBase], + dataset_item_features: Dataset, + default_trainer: bool, + trainer: Trainer, ) -> None: config = { "deterministic": True, - "item_net_block_types": (IdEmbeddingsItemNet,), # TODO: add CatFeaturesItemNet + "item_net_block_types": (IdEmbeddingsItemNet, CatFeaturesItemNet), } if not default_trainer: config["get_trainer_func"] = custom_trainer model = model_cls.from_config(config) - model.fit(dataset) - assert_save_load_do_not_change_model(model, dataset) + model.fit(dataset_item_features) + assert_save_load_do_not_change_model(model, dataset_item_features) @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) def test_load_from_checkpoint( From 703db1a11d28e19bf9ecdc87ca0b82a8a7a78e31 Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Thu, 6 Feb 2025 15:51:53 +0300 Subject: [PATCH 5/5] added mode=sum to embedding_bag --- examples/bert4rec.ipynb | 587 ------ examples/sasrec_metrics_comp.ipynb | 2926 ---------------------------- rectools/models/nn/item_net.py | 2 +- 3 files changed, 1 insertion(+), 3514 deletions(-) delete mode 100644 examples/bert4rec.ipynb delete mode 100644 examples/sasrec_metrics_comp.ipynb diff --git a/examples/bert4rec.ipynb b/examples/bert4rec.ipynb deleted file mode 100644 index 9bd18521..00000000 --- a/examples/bert4rec.ipynb +++ /dev/null @@ -1,587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "import threadpoolctl\n", - "from pathlib import Path\n", - "from lightning_fabric import seed_everything\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from rectools import Columns\n", - "\n", - "\n", - "from rectools.dataset import Dataset\n", - "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", - "from rectools.models import BERT4RecModel\n", - "from rectools.models.nn.item_net import IdEmbeddingsItemNet" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", - "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", - "threadpoolctl.threadpool_limits(1, \"blas\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Prepare data" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# %%time\n", - "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip\n", - "# !unzip -o data_original.zip\n", - "# !rm data_original.zip" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_PATH = Path(\"data_original\")\n", - "\n", - "interactions = (\n", - " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", - " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", - "\n", - "# Split to train / test\n", - "max_date = interactions[Columns.Datetime].max()\n", - "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", - "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", - "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", - "\n", - "# drop items with less than 20 interactions in train\n", - "items = train[\"item_id\"].value_counts()\n", - "items = items[items >= 20]\n", - "items = items.index.to_list()\n", - "train = train[train[\"item_id\"].isin(items)]\n", - " \n", - "# drop users with less than 2 interactions in train\n", - "users = train[\"user_id\"].value_counts()\n", - "users = users[users >= 2]\n", - "users = users.index.to_list()\n", - "train = train[(train[\"user_id\"].isin(users))]\n", - "\n", - "users = train[\"user_id\"].drop_duplicates().to_list()\n", - "\n", - "# drop cold users from test\n", - "test_users_sasrec = test[Columns.User].unique()\n", - "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", - "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", - "test_users = test[Columns.User].unique()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "items = pd.read_csv(DATA_PATH / 'items.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# Process item features to the form of a flatten dataframe\n", - "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", - "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", - "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", - "genre_feature.columns = [\"id\", \"value\"]\n", - "genre_feature[\"feature\"] = \"genre\"\n", - "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", - "content_feature.columns = [\"id\", \"value\"]\n", - "content_feature[\"feature\"] = \"content_type\"\n", - "item_features = pd.concat((genre_feature, content_feature))\n", - "\n", - "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", - "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", - "\n", - "catalog=train[Columns.Item].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_no_features = Dataset.construct(\n", - " interactions_df=train,\n", - ")\n", - "\n", - "dataset_item_features = Dataset.construct(\n", - " interactions_df=train,\n", - " item_features_df=item_features,\n", - " cat_item_features=[\"genre\", \"content_type\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "metrics_name = {\n", - " 'MAP': MAP,\n", - " 'MIUF': MeanInvUserFreq,\n", - " 'Serendipity': Serendipity\n", - " \n", - "\n", - "}\n", - "metrics = {}\n", - "for metric_name, metric in metrics_name.items():\n", - " for k in (1, 5, 10):\n", - " metrics[f'{metric_name}@{k}'] = metric(k=k)\n", - "\n", - "# list with metrics results of all models\n", - "features_results = []\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# BERT4Rec" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### BERT4Rec with item ids embeddings in ItemNetBlock" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "model = BERT4RecModel(\n", - " n_blocks=3,\n", - " n_heads=4,\n", - " dropout_rate=0.2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " mask_prob=0.5,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 1.3 M \n", - "---------------------------------------------------------------\n", - "1.3 M Trainable params\n", - "0 Non-trainable params\n", - "1.3 M Total params\n", - "5.292 Total estimated model params size (MB)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fee32d8cfa0147689a003b073d52fbd5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:313: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/data/home/amsemenov2/git/RecTools_origin/RecTools/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "af7987f2168242baa2476d3b6c955440", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
07344697280.6417741
17344677930.4246962
27344678290.3713493
373446143170.3570834
47344631820.0686615
...............
94704585716237341.2114546
94704685716286361.1050207
94704785716241511.0535188
947048857162147030.8832019
947049857162129950.82724610
\n", - "

947050 rows × 4 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 9728 0.641774 1\n", - "1 73446 7793 0.424696 2\n", - "2 73446 7829 0.371349 3\n", - "3 73446 14317 0.357083 4\n", - "4 73446 3182 0.068661 5\n", - "... ... ... ... ...\n", - "947045 857162 3734 1.211454 6\n", - "947046 857162 8636 1.105020 7\n", - "947047 857162 4151 1.053518 8\n", - "947048 857162 14703 0.883201 9\n", - "947049 857162 12995 0.827246 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'MAP@1': 0.04615447203061492,\n", - " 'MAP@5': 0.07738831584888614,\n", - " 'MAP@10': 0.08574348640312766,\n", - " 'MIUF@1': 3.904076196457114,\n", - " 'MIUF@5': 4.52025063365768,\n", - " 'MIUF@10': 5.012999210434636,\n", - " 'Serendipity@1': 0.0004970568200398246,\n", - " 'Serendipity@5': 0.000480108457862388,\n", - " 'Serendipity@10': 0.00045721034938317576,\n", - " 'model': 'bert4rec_ids'}]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rectools_origin", - "language": "python", - "name": "rectools_origin" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb deleted file mode 100644 index d2131880..00000000 --- a/examples/sasrec_metrics_comp.ipynb +++ /dev/null @@ -1,2926 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append(\"../\")" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "import os\n", - "import threadpoolctl\n", - "import torch\n", - "from pathlib import Path\n", - "from lightning_fabric import seed_everything\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from rectools import Columns\n", - "\n", - "from implicit.als import AlternatingLeastSquares\n", - "\n", - "from rectools.dataset import Dataset\n", - "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", - "from rectools.models import ImplicitALSWrapperModel\n", - "from rectools.models import SASRecModel\n", - "from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", - "\n", - "# For implicit ALS\n", - "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", - "threadpoolctl.threadpool_limits(1, \"blas\")\n", - "\n", - "logging.basicConfig()\n", - "logging.getLogger().setLevel(logging.INFO)\n", - "logger = logging.getLogger()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# %%time\n", - "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip\n", - "# !unzip -o data_original.zip\n", - "# !rm data_original.zip" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "DATA_PATH = Path(\"data_original\")\n", - "\n", - "interactions = (\n", - " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", - " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Split dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", - "\n", - "# Split to train / test\n", - "max_date = interactions[Columns.Datetime].max()\n", - "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", - "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", - "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", - "\n", - "# drop items with less than 20 interactions in train\n", - "items = train[\"item_id\"].value_counts()\n", - "items = items[items >= 20]\n", - "items = items.index.to_list()\n", - "train = train[train[\"item_id\"].isin(items)]\n", - " \n", - "# drop users with less than 2 interactions in train\n", - "users = train[\"user_id\"].value_counts()\n", - "users = users[users >= 2]\n", - "users = users.index.to_list()\n", - "train = train[(train[\"user_id\"].isin(users))]\n", - "\n", - "users = train[\"user_id\"].drop_duplicates().to_list()\n", - "\n", - "# drop cold users from test\n", - "test_users_sasrec = test[Columns.User].unique()\n", - "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", - "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", - "test_users = test[Columns.User].unique()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "items = pd.read_csv(DATA_PATH / 'items.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# Process item features to the form of a flatten dataframe\n", - "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", - "\n", - "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", - "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", - "genre_feature.columns = [\"id\", \"value\"]\n", - "genre_feature[\"feature\"] = \"genre\"\n", - "\n", - "items[\"director\"] = items[\"directors\"].str.lower().replace(\", \", \",\", regex=False).str.split(\",\")\n", - "directors_feature = items[[\"item_id\", \"director\"]].explode(\"director\")\n", - "directors_feature.columns = [\"id\", \"value\"]\n", - "directors_feature[\"feature\"] = \"director\"\n", - "\n", - "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", - "content_feature.columns = [\"id\", \"value\"]\n", - "content_feature[\"feature\"] = \"content_type\"\n", - "item_features_genre_content = pd.concat((genre_feature, content_feature))\n", - "item_features_genre_director = pd.concat((genre_feature, directors_feature))\n", - "\n", - "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", - "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", - "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", - "\n", - "catalog=train[Columns.Item].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_no_features = Dataset.construct(\n", - " interactions_df=train,\n", - ")\n", - "\n", - "dataset_item_features = Dataset.construct(\n", - " interactions_df=train,\n", - " item_features_df=item_features_genre_content,\n", - " cat_item_features=[\"genre\", \"content_type\"],\n", - ")\n", - "\n", - "dataset_item_features_genre_director = Dataset.construct(\n", - " interactions_df=train,\n", - " item_features_df=item_features_genre_director,\n", - " cat_item_features=[\"genre\", \"director\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "metrics_name = {\n", - " 'MAP': MAP,\n", - " 'MIUF': MeanInvUserFreq,\n", - " 'Serendipity': Serendipity\n", - " \n", - "\n", - "}\n", - "metrics = {}\n", - "for metric_name, metric in metrics_name.items():\n", - " for k in (1, 5, 10):\n", - " metrics[f'{metric_name}@{k}'] = metric(k=k)\n", - "\n", - "# list with metrics results of all models\n", - "features_results = []" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SASRec" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Softmax loss" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n" - ] - } - ], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", - " unq_values = pd.unique(values)\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", - "-----------------------------------------------------------------------\n", - "2.2 M Trainable params\n", - "0 Non-trainable params\n", - "2.2 M Total params\n", - "8.991 Total estimated model params size (MB)\n", - "36 Modules in train mode\n", - "0 Modules in eval mode\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "06699b98353c4d2298bb1037f77fdf71", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_item_features_genre_director)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6a4509eead0042a28799087cefb68d47", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
07344697282.4018811
17344677931.9230692
27344637841.8246133
37344631821.6665284
47344678291.6621765
...............
947045857162129952.3854326
94704685716268092.3609357
9470478571626571.9409318
94704885716247021.8664799
947049857162164471.75802710
\n", - "

947050 rows × 4 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 9728 2.401881 1\n", - "1 73446 7793 1.923069 2\n", - "2 73446 3784 1.824613 3\n", - "3 73446 3182 1.666528 4\n", - "4 73446 7829 1.662176 5\n", - "... ... ... ... ...\n", - "947045 857162 12995 2.385432 6\n", - "947046 857162 6809 2.360935 7\n", - "947047 857162 657 1.940931 8\n", - "947048 857162 4702 1.866479 9\n", - "947049 857162 16447 1.758027 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## BCE loss" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" - ] - } - ], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " loss=\"BCE\",\n", - " n_negatives=2,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", - " unq_values = pd.unique(values)\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", - "-----------------------------------------------------------------------\n", - "2.2 M Trainable params\n", - "0 Non-trainable params\n", - "2.2 M Total params\n", - "8.991 Total estimated model params size (MB)\n", - "36 Modules in train mode\n", - "0 Modules in eval mode\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "977a6b52aaf24f4fa0ccb05a7c9804ba", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4aa1c3be58ae41bb9aa81274121df9ed", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_iditem_idscorerank
07344631823.3702861
173446129653.0880012
27344667743.0569053
373446162702.9669684
47344675822.9657085
...............
94704585716241512.7330066
9470468571621422.6873157
94704785716297282.6347418
94704885716237342.5589339
94704985716299962.47984910
\n", - "

947050 rows × 4 columns

\n", - "" - ], - "text/plain": [ - " user_id item_id score rank\n", - "0 73446 3182 3.370286 1\n", - "1 73446 12965 3.088001 2\n", - "2 73446 6774 3.056905 3\n", - "3 73446 16270 2.966968 4\n", - "4 73446 7582 2.965708 5\n", - "... ... ... ... ...\n", - "947045 857162 4151 2.733006 6\n", - "947046 857162 142 2.687315 7\n", - "947047 857162 9728 2.634741 8\n", - "947048 857162 3734 2.558933 9\n", - "947049 857162 9996 2.479849 10\n", - "\n", - "[947050 rows x 4 columns]" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## gBCE loss" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 32\n" - ] - }, - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "RANDOM_SEED = 32\n", - "torch.use_deterministic_algorithms(True)\n", - "seed_everything(RANDOM_SEED, workers=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" - ] - } - ], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " loss=\"gBCE\",\n", - " n_negatives=256,\n", - " gbce_t=0.75,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", - " unq_values = pd.unique(values)\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", - "-----------------------------------------------------------------------\n", - "2.2 M Trainable params\n", - "0 Non-trainable params\n", - "2.2 M Total params\n", - "8.991 Total estimated model params size (MB)\n", - "36 Modules in train mode\n", - "0 Modules in eval mode\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ccf86bc484ed4bbe804a088a765ff054", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8e657a67ee204507a61fb6e6489248bd", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b2a472b585d44beb930aca7c172982cc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
MAP@1MAP@5MAP@10MIUF@1MIUF@5MIUF@10Serendipity@1Serendipity@5Serendipity@10
model
softmax_padding_mask0.0494490.0833990.0925353.7036824.3036754.9371420.0009600.0007270.000670
softmax0.0484660.0816950.0907043.8714264.5730695.1597420.0011170.0008650.000763
gBCE0.0408480.0723560.0801662.3323973.0937633.9422050.0001030.0001180.000134
bce0.0270350.0512440.0590803.8820814.3843144.7342980.0001040.0001210.000131
\n", - "" - ], - "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", - "model \n", - "softmax_padding_mask 0.049449 0.083399 0.092535 3.703682 4.303675 \n", - "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 \n", - "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 \n", - "bce 0.027035 0.051244 0.059080 3.882081 4.384314 \n", - "\n", - " MIUF@10 Serendipity@1 Serendipity@5 Serendipity@10 \n", - "model \n", - "softmax_padding_mask 4.937142 0.000960 0.000727 0.000670 \n", - "softmax 5.159742 0.001117 0.000865 0.000763 \n", - "gBCE 3.942205 0.000103 0.000118 0.000134 \n", - "bce 4.734298 0.000104 0.000121 0.000131 " - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_df = (\n", - " pd.DataFrame(features_results)\n", - " .set_index(\"model\")\n", - " .sort_values(by=[\"MAP@10\", \"Serendipity@10\"], ascending=False)\n", - ")\n", - "features_df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### sasrec with item ids embeddings in ItemNetBlock" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" - ] - } - ], - "source": [ - "model = SASRecModel(\n", - " n_blocks=2,\n", - " session_max_len=32,\n", - " lr=1e-3,\n", - " epochs=5,\n", - " verbose=1,\n", - " deterministic=True,\n", - " item_net_block_types=(IdEmbeddingsItemNet, ), # Use only item ids in ItemNetBlock\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/dataset/identifiers.py:60: FutureWarning: unique with argument that is not not a Series, Index, ExtensionArray, or np.ndarray is deprecated and will raise in a future version.\n", - " unq_values = pd.unique(values)\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------------------------------------\n", - "0 | torch_model | TransformerBasedSessionEncoder | 2.2 M | train\n", - "-----------------------------------------------------------------------\n", - "2.2 M Trainable params\n", - "0 Non-trainable params\n", - "2.2 M Total params\n", - "8.991 Total estimated model params size (MB)\n", - "36 Modules in train mode\n", - "0 Modules in eval mode\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "115180caaade47eb91e2121666f5bb31", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%%time\n", - "model.fit(dataset_no_features)" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "31ef580c7fa3448d9f0e6122149377a6", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#%%time\n", - "model.fit(dataset_item_features_genre_director)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "970a1097f4d947f3b4d841eeed2e7a61", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#%%time\n", - "model.fit(dataset_item_features_genre_director)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/nn/transformer_data_preparator.py:322: UserWarning: 91202 target users were considered cold because of missing known items\n", - " warnings.warn(explanation)\n", - "/data/home/maspirina1/git_repos/RecTools/examples/../rectools/models/base.py:675: UserWarning: \n", - " Model `` doesn't support recommendations for cold users,\n", - " but some of given users are cold: they are not in the `dataset.user_id_map`\n", - " \n", - " warnings.warn(explanation)\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", - "/home/maspirina1/git_repos/RecTools/venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=143` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "aee07340143145f080a9a1abc19dc6fb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Predicting: | | 0/? [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
MAP@1MAP@5MAP@10MIUF@1MIUF@5MIUF@10Serendipity@1Serendipity@5Serendipity@10
model
softmax_padding_mask0.0494490.0833990.0925353.7036824.3036754.9371420.0009600.0007270.000670
sasrec_ids0.0482320.0822870.0914953.6332644.3088514.9660900.0010090.0008030.000729
softmax0.0484660.0816950.0907043.8714264.5730695.1597420.0011170.0008650.000763
sasrec_id_and_cat_features0.0469810.0806450.0900084.3303474.9743375.5104500.0015170.0011090.000966
gBCE0.0408480.0723560.0801662.3323973.0937633.9422050.0001030.0001180.000134
sasrec_cat_features0.0440480.0716680.0795154.0098855.5158726.0063210.0009660.0008130.000715
bce0.0270350.0512440.0590803.8820814.3843144.7342980.0001040.0001210.000131
\n", - "" - ], - "text/plain": [ - " MAP@1 MAP@5 MAP@10 MIUF@1 MIUF@5 \\\n", - "model \n", - "softmax_padding_mask 0.049449 0.083399 0.092535 3.703682 4.303675 \n", - "sasrec_ids 0.048232 0.082287 0.091495 3.633264 4.308851 \n", - "softmax 0.048466 0.081695 0.090704 3.871426 4.573069 \n", - "sasrec_id_and_cat_features 0.046981 0.080645 0.090008 4.330347 4.974337 \n", - "gBCE 0.040848 0.072356 0.080166 2.332397 3.093763 \n", - "sasrec_cat_features 0.044048 0.071668 0.079515 4.009885 5.515872 \n", - "bce 0.027035 0.051244 0.059080 3.882081 4.384314 \n", - "\n", - " MIUF@10 Serendipity@1 Serendipity@5 \\\n", - "model \n", - "softmax_padding_mask 4.937142 0.000960 0.000727 \n", - "sasrec_ids 4.966090 0.001009 0.000803 \n", - "softmax 5.159742 0.001117 0.000865 \n", - "sasrec_id_and_cat_features 5.510450 0.001517 0.001109 \n", - "gBCE 3.942205 0.000103 0.000118 \n", - "sasrec_cat_features 6.006321 0.000966 0.000813 \n", - "bce 4.734298 0.000104 0.000121 \n", - "\n", - " Serendipity@10 \n", - "model \n", - "softmax_padding_mask 0.000670 \n", - "sasrec_ids 0.000729 \n", - "softmax 0.000763 \n", - "sasrec_id_and_cat_features 0.000966 \n", - "gBCE 0.000134 \n", - "sasrec_cat_features 0.000715 \n", - "bce 0.000131 " - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features_df = (\n", - " pd.DataFrame(features_results)\n", - " .set_index(\"model\")\n", - " .sort_values(by=[\"MAP@10\", \"Serendipity@10\"], ascending=False)\n", - ")\n", - "features_df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Item to item" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "target_items = [13865, 4457, 15297]" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 10.3 s, sys: 600 ms, total: 10.9 s\n", - "Wall time: 1.68 s\n" - ] - } - ], - "source": [ - "%%time\n", - "recos = model.recommend_to_items(\n", - " target_items=target_items, \n", - " dataset=dataset_no_features,\n", - " k=10,\n", - " filter_itself=True,\n", - " items_to_recommend=None, #white_list,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
target_item_iditem_idscorerank
013865118631.0000001
11386571071.0000002
21386564090.5595783
3138651420.5321084
41386544570.5313605
51386568090.5310586
613865104400.5225967
71386526570.5128298
813865107720.5113749
913865152970.47755310
104457147410.7431631
114457129950.6443022
12445764550.6295463
1344571420.6243344
14445712870.6227135
1544572740.6091026
164457107720.6042507
17445739350.6010688
18445735090.5858279
19445793420.57565110
20152971420.6124351
211529726570.5655472
221529768090.5651213
2315297104400.5628964
2415297107720.5294995
251529739350.5205206
261529744570.5158627
2715297143370.5063328
281529718440.5063329
291529786360.50372910
\n", - "
" - ], - "text/plain": [ - " target_item_id item_id score rank\n", - "0 13865 11863 1.000000 1\n", - "1 13865 7107 1.000000 2\n", - "2 13865 6409 0.559578 3\n", - "3 13865 142 0.532108 4\n", - "4 13865 4457 0.531360 5\n", - "5 13865 6809 0.531058 6\n", - "6 13865 10440 0.522596 7\n", - "7 13865 2657 0.512829 8\n", - "8 13865 10772 0.511374 9\n", - "9 13865 15297 0.477553 10\n", - "10 4457 14741 0.743163 1\n", - "11 4457 12995 0.644302 2\n", - "12 4457 6455 0.629546 3\n", - "13 4457 142 0.624334 4\n", - "14 4457 1287 0.622713 5\n", - "15 4457 274 0.609102 6\n", - "16 4457 10772 0.604250 7\n", - "17 4457 3935 0.601068 8\n", - "18 4457 3509 0.585827 9\n", - "19 4457 9342 0.575651 10\n", - "20 15297 142 0.612435 1\n", - "21 15297 2657 0.565547 2\n", - "22 15297 6809 0.565121 3\n", - "23 15297 10440 0.562896 4\n", - "24 15297 10772 0.529499 5\n", - "25 15297 3935 0.520520 6\n", - "26 15297 4457 0.515862 7\n", - "27 15297 14337 0.506332 8\n", - "28 15297 1844 0.506332 9\n", - "29 15297 8636 0.503729 10" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "recos" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
target_item_iditem_idscoreranktitle
013865118631.0000001Девятаев - сериал
11386571071.0000002Девятаев
21386564090.5595783Особо опасен
3138651420.5321084Маша
41386544570.53136052067: Петля времени
51386568090.5310586Дуров
613865104400.5225967Хрустальный
71386526570.5128298Подслушано
813865107720.5113749Зелёная книга
913865152970.47755310Клиника счастья
104457147410.7431631Цвет из иных миров
114457129950.6443022Восемь сотен
12445764550.6295463Альфа
1344571420.6243344Маша
14445712870.6227135Терминатор: Тёмные судьбы
1544572740.6091026Логан
164457107720.6042507Зелёная книга
17445739350.6010688Бывшая с того света
18445735090.5858279Комната желаний
19445793420.57565110Дэдпул
20152971420.6124351Маша
211529726570.5655472Подслушано
221529768090.5651213Дуров
2315297104400.5628964Хрустальный
2415297107720.5294995Зелёная книга
251529739350.5205206Бывшая с того света
261529744570.51586272067: Петля времени
2715297143370.5063328[4К] Аферистка
281529718440.5063329Аферистка
291529786360.50372910Белый снег
\n", - "
" - ], - "text/plain": [ - " target_item_id item_id score rank title\n", - "0 13865 11863 1.000000 1 Девятаев - сериал\n", - "1 13865 7107 1.000000 2 Девятаев\n", - "2 13865 6409 0.559578 3 Особо опасен\n", - "3 13865 142 0.532108 4 Маша\n", - "4 13865 4457 0.531360 5 2067: Петля времени\n", - "5 13865 6809 0.531058 6 Дуров\n", - "6 13865 10440 0.522596 7 Хрустальный\n", - "7 13865 2657 0.512829 8 Подслушано\n", - "8 13865 10772 0.511374 9 Зелёная книга\n", - "9 13865 15297 0.477553 10 Клиника счастья\n", - "10 4457 14741 0.743163 1 Цвет из иных миров\n", - "11 4457 12995 0.644302 2 Восемь сотен\n", - "12 4457 6455 0.629546 3 Альфа\n", - "13 4457 142 0.624334 4 Маша\n", - "14 4457 1287 0.622713 5 Терминатор: Тёмные судьбы\n", - "15 4457 274 0.609102 6 Логан\n", - "16 4457 10772 0.604250 7 Зелёная книга\n", - "17 4457 3935 0.601068 8 Бывшая с того света\n", - "18 4457 3509 0.585827 9 Комната желаний\n", - "19 4457 9342 0.575651 10 Дэдпул\n", - "20 15297 142 0.612435 1 Маша\n", - "21 15297 2657 0.565547 2 Подслушано\n", - "22 15297 6809 0.565121 3 Дуров\n", - "23 15297 10440 0.562896 4 Хрустальный\n", - "24 15297 10772 0.529499 5 Зелёная книга\n", - "25 15297 3935 0.520520 6 Бывшая с того света\n", - "26 15297 4457 0.515862 7 2067: Петля времени\n", - "27 15297 14337 0.506332 8 [4К] Аферистка\n", - "28 15297 1844 0.506332 9 Аферистка\n", - "29 15297 8636 0.503729 10 Белый снег" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# TODO: change model for recos (here is the last one trained and is is the worst in quality)\n", - "recos.merge(items[[\"item_id\", \"title\"]], on=\"item_id\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 5009764c..d215f7c6 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -78,7 +78,7 @@ def __init__( super().__init__() self.n_cat_feature_values = n_cat_feature_values - self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors) + self.embedding_bag = nn.EmbeddingBag(num_embeddings=n_cat_feature_values, embedding_dim=n_factors, mode="sum") self.drop_layer = nn.Dropout(dropout_rate) self.register_buffer("offsets", offsets)