diff --git a/CHANGELOG.md b/CHANGELOG.md index b8ab28fe..a9eb3c3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- Methods for converting `Interactions` to raw form and for getting raw interactions from `Dataset` ([#69](https://github.com/MobileTeleSystems/RecTools/pull/69)) + ### Changed - Loosened `pandas`, `torch` and `torch-light` versions for `python >= 3.8` ([#58](https://github.com/MobileTeleSystems/RecTools/pull/58)) diff --git a/rectools/dataset/dataset.py b/rectools/dataset/dataset.py index 4e318fe0..1f836d11 100644 --- a/rectools/dataset/dataset.py +++ b/rectools/dataset/dataset.py @@ -167,8 +167,6 @@ def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matri """ Construct user-item CSR matrix based on `interactions` attribute. - `Interactions.get_user_item_matrix` is used, see its documentation for details. - Return a resized user-item matrix. Resizing is done using `user_id_map` and `item_id_map`, hence if either a user or an item is not presented in interactions, @@ -188,3 +186,20 @@ def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matri matrix = self.interactions.get_user_item_matrix(include_weights) matrix.resize(self.user_id_map.internal_ids.size, self.item_id_map.internal_ids.size) return matrix + + def get_raw_interactions(self, include_weight: bool = True, include_datetime: bool = True) -> pd.DataFrame: + """ + Return interactions as a `pd.DataFrame` object, replacing internal user and item ids with external ones. + + Parameters + ---------- + include_weight : bool, default ``True`` + Whether to include weight column into resulting table or not. + include_datetime : bool, default ``True`` + Whether to include datetime column into resulting table or not. 
+ + Returns + ------- + pd.DataFrame + """ + return self.interactions.to_external(self.user_id_map, self.item_id_map, include_weight, include_datetime) diff --git a/rectools/dataset/interactions.py b/rectools/dataset/interactions.py index 3b1ccf52..fc2dd0c5 100644 --- a/rectools/dataset/interactions.py +++ b/rectools/dataset/interactions.py @@ -27,7 +27,7 @@ @attr.s(frozen=True, slots=True) class Interactions: """ - Structure to storage info about user-item interactions. + Structure to store info about user-item interactions. Usually it's more convenient to use `from_raw` method instead of direct creating. @@ -123,9 +123,7 @@ def from_raw( def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matrix: """ - Form an user-item CSR matrix based on `interactions` attribute. - - It is used `Interactions.get_user_item_matrix`, see its documentations for details. + Form a user-item CSR matrix based on interactions data. Parameters ---------- @@ -152,3 +150,42 @@ def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matri ), ) return csr + + def to_external( + self, + user_id_map: IdMap, + item_id_map: IdMap, + include_weight: bool = True, + include_datetime: bool = True, + ) -> pd.DataFrame: + """ + Convert itself to `pd.DataFrame`, replacing internal user and item ids with external ones. + + Parameters + ---------- + user_id_map : IdMap + User id map that has to be used for converting internal user ids to external ones. + item_id_map : IdMap + Item id map that has to be used for converting internal item ids to external ones. + include_weight : bool, default ``True`` + Whether to include weight column into resulting table or not. + include_datetime : bool, default ``True`` + Whether to include datetime column into resulting table or not. 
+ + Returns + ------- + pd.DataFrame + """ + res = pd.DataFrame( + { + Columns.User: user_id_map.convert_to_external(self.df[Columns.User].values), + Columns.Item: item_id_map.convert_to_external(self.df[Columns.Item].values), + } + ) + + if include_weight: + res[Columns.Weight] = self.df[Columns.Weight].values  # positional copy: `res` has a fresh index + if include_datetime: + res[Columns.Datetime] = self.df[Columns.Datetime].values  # positional copy: `res` has a fresh index + + return res diff --git a/tests/dataset/test_dataset.py b/tests/dataset/test_dataset.py index da28df2c..935cce98 100644 --- a/tests/dataset/test_dataset.py +++ b/tests/dataset/test_dataset.py @@ -195,3 +195,15 @@ def test_raises_when_in_sparse_features_present_ids_that_not_present_in_interact item_features_df=item_features_df, cat_item_features=["f2"], ) + + @pytest.mark.parametrize("include_weight", (True, False)) + @pytest.mark.parametrize("include_datetime", (True, False)) + def test_get_raw_interactions(self, include_weight: bool, include_datetime: bool) -> None: + dataset = Dataset.construct(self.interactions_df) + actual = dataset.get_raw_interactions(include_weight, include_datetime) + expected = self.interactions_df.astype({Columns.Weight: "float64", Columns.Datetime: "datetime64[ns]"}) + if not include_weight: + expected.drop(columns=Columns.Weight, inplace=True) + if not include_datetime: + expected.drop(columns=Columns.Datetime, inplace=True) + pd.testing.assert_frame_equal(actual, expected) diff --git a/tests/dataset/test_interactions.py b/tests/dataset/test_interactions.py index 8b4be3a4..86ec6e08 100644 --- a/tests/dataset/test_interactions.py +++ b/tests/dataset/test_interactions.py @@ -102,3 +102,55 @@ def test_raises_when_datetime_type_incorrect(self) -> None: Interactions.from_raw(df, IdMap.from_values(df[Columns.User]), IdMap.from_values(df[Columns.Item])) err_text = e.value.args[0] assert Columns.Datetime in err_text.lower() + + @pytest.mark.parametrize("include_weight", (True, False)) + @pytest.mark.parametrize("include_datetime", (True, False)) + def 
test_to_external(self, include_weight: bool, include_datetime: bool) -> None: + user_id_map = IdMap(np.array([10, 20, 30])) + item_id_map = IdMap(np.array(["i1", "i2"])) + interactions = Interactions(self.df) + + actual = interactions.to_external(user_id_map, item_id_map, include_weight, include_datetime) + expected = pd.DataFrame( + [ + [20, "i1"], + [30, "i2"], + [20, "i1"], + [20, "i2"], + ], + columns=Columns.UserItem, + ) + if include_weight: + expected[Columns.Weight] = self.df[Columns.Weight] + if include_datetime: + expected[Columns.Datetime] = self.df[Columns.Datetime] + + pd.testing.assert_frame_equal(actual, expected) + + def test_to_external_empty(self) -> None: + user_id_map = IdMap(np.array([10, 20, 30])) + item_id_map = IdMap(np.array(["i1", "i2"])) + interactions = Interactions(self.df.iloc[:0]) + + actual = interactions.to_external(user_id_map, item_id_map) + expected = pd.DataFrame( + [], + columns=Columns.Interactions, + ) + expected = expected.astype( + { + Columns.User: np.int64, + Columns.Item: "object", + Columns.Weight: np.float64, + Columns.Datetime: "datetime64[ns]", + } + ) + pd.testing.assert_frame_equal(actual, expected, check_index_type=False) + + def test_to_external_with_missing_ids(self) -> None: + user_id_map = IdMap(np.array([10, 20, 30])) + item_id_map = IdMap(np.array(["i1"])) + interactions = Interactions(self.df) + + with pytest.raises(KeyError): + interactions.to_external(user_id_map, item_id_map)