Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into experimental/sasrec
  • Loading branch information
blondered committed Feb 6, 2025
2 parents c8e861f + 5664a3c commit c679553
Show file tree
Hide file tree
Showing 12 changed files with 395 additions and 15 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased

### Added
- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252))


## [0.10.0] - 16.01.2025

### Added
Expand Down
119 changes: 117 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,21 @@ pytorch-lightning = {version = ">=1.6.0, <3.0.0", optional = true}
ipywidgets = {version = ">=7.7,<8.2", optional = true}
plotly = {version="^5.22.0", optional = true}
nbformat = {version = ">=4.2.0", optional = true}
cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true}


[tool.poetry.extras]
lightfm = ["rectools-lightfm"]
nmslib = ["nmslib", "nmslib-metabrainz"]
torch = ["torch", "pytorch-lightning"]
visuals = ["ipywidgets", "plotly", "nbformat"]
cupy = ["cupy-cuda12x"]
all = [
"rectools-lightfm",
"nmslib", "nmslib-metabrainz",
"torch", "pytorch-lightning",
"ipywidgets", "plotly", "nbformat",
"cupy-cuda12x",
]


Expand Down
4 changes: 3 additions & 1 deletion rectools/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
`models.DSSMModel`
`models.EASEModel`
`models.ImplicitALSWrapperModel`
`models.ImplicitBPRWrapperModel`
`models.ImplicitItemKNNWrapperModel`
`models.LightFMWrapperModel`
`models.PopularModel`
Expand All @@ -48,7 +49,7 @@
from .popular_in_category import PopularInCategoryModel
from .pure_svd import PureSVDModel
from .random import RandomModel
from .serialization import load_model, model_from_config
from .serialization import load_model, model_from_config, model_from_params

try:
from .lightfm import LightFMWrapperModel
Expand Down Expand Up @@ -76,4 +77,5 @@
"DSSMModel",
"load_model",
"model_from_config",
"model_from_params",
)
22 changes: 21 additions & 1 deletion rectools/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from rectools.exceptions import NotFittedError
from rectools.types import ExternalIdsArray, InternalIdsArray
from rectools.utils.config import BaseConfig
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes

T = tp.TypeVar("T", bound="ModelBase")
Expand Down Expand Up @@ -210,6 +210,26 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:

return cls._from_config(config_obj)

@classmethod
def from_params(cls, params: tp.Dict[str, tp.Any], sep: str = ".") -> tpe.Self:
    """
    Build a model from a flat dict of parameters.

    The flat dict is converted to a nested config dict with `unflatten_dict`
    and the result is passed to `from_config`.

    Parameters
    ----------
    params : dict
        Model parameters as a flat dict with keys separated by `sep`.
    sep : str, default "."
        Separator for nested keys.

    Returns
    -------
    Model instance.
    """
    return cls.from_config(unflatten_dict(params, sep=sep))

@classmethod
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
raise NotImplementedError()
Expand Down
41 changes: 38 additions & 3 deletions rectools/models/pure_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""SVD Model."""

import typing as tp
import warnings

import numpy as np
import typing_extensions as tpe
Expand All @@ -26,6 +27,15 @@
from rectools.models.rank import Distance
from rectools.models.vector import Factors, VectorModel

try:
import cupy as cp
from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix
from cupyx.scipy.sparse.linalg import svds as cupy_svds
except ImportError: # pragma: no cover
cupy_svds = None
cp_csr_matrix = None
cp = None


class PureSVDModelConfig(ModelConfig):
"""Config for `PureSVD` model."""
Expand All @@ -34,6 +44,7 @@ class PureSVDModelConfig(ModelConfig):
tol: float = 0
maxiter: tp.Optional[int] = None
random_state: tp.Optional[int] = None
use_gpu: tp.Optional[bool] = False
recommend_n_threads: int = 0
recommend_use_gpu_ranking: bool = True

Expand All @@ -53,7 +64,9 @@ class PureSVDModel(VectorModel[PureSVDModelConfig]):
maxiter : int, optional, default ``None``
Maximum number of iterations.
random_state : int, optional, default ``None``
Pseudorandom number generator state used to generate resamples.
Pseudorandom number generator state used to generate resamples. Omitted if use_gpu is True.
use_gpu : bool, default ``False``
If ``True``, `cupyx.scipy.sparse.linalg.svds()` is used instead of SciPy. CuPy is required.
verbose : int, default ``0``
Degree of verbose output. If ``0``, no output will be provided.
recommend_n_threads: int, default 0
Expand Down Expand Up @@ -83,6 +96,7 @@ def __init__(
tol: float = 0,
maxiter: tp.Optional[int] = None,
random_state: tp.Optional[int] = None,
use_gpu: tp.Optional[bool] = False,
verbose: int = 0,
recommend_n_threads: int = 0,
recommend_use_gpu_ranking: bool = True,
Expand All @@ -93,6 +107,16 @@ def __init__(
self.tol = tol
self.maxiter = maxiter
self.random_state = random_state
self._use_gpu = use_gpu # for making a config
if use_gpu: # pragma: no cover
if not cp:
warnings.warn("Forced to use CPU. CuPy is not available.")
use_gpu = False
elif not cp.cuda.is_available():
warnings.warn("Forced to use CPU. GPU is not available.")
use_gpu = False

self.use_gpu = use_gpu
self.recommend_n_threads = recommend_n_threads
self.recommend_use_gpu_ranking = recommend_use_gpu_ranking

Expand All @@ -106,6 +130,7 @@ def _get_config(self) -> PureSVDModelConfig:
tol=self.tol,
maxiter=self.maxiter,
random_state=self.random_state,
use_gpu=self._use_gpu,
verbose=self.verbose,
recommend_n_threads=self.recommend_n_threads,
recommend_use_gpu_ranking=self.recommend_use_gpu_ranking,
Expand All @@ -118,6 +143,7 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self:
tol=config.tol,
maxiter=config.maxiter,
random_state=config.random_state,
use_gpu=config.use_gpu,
verbose=config.verbose,
recommend_n_threads=config.recommend_n_threads,
recommend_use_gpu_ranking=config.recommend_use_gpu_ranking,
Expand All @@ -126,10 +152,19 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self:
def _fit(self, dataset: Dataset) -> None:  # type: ignore
    """
    Fit the model with a truncated SVD of the user-item matrix.

    The decomposition runs exactly once: on GPU through
    `cupyx.scipy.sparse.linalg.svds` when `self.use_gpu` is truthy, otherwise
    on CPU through SciPy's `svds`. After fitting, `self.user_factors` holds `u`
    and `self.item_factors` holds `(diag(sigma) @ vt).T`.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose weighted user-item matrix is decomposed.
    """
    ui_csr = dataset.get_user_item_matrix(include_weights=True)

    if self.use_gpu:  # pragma: no cover
        ui_csr = cp_csr_matrix(ui_csr)
        # To prevent IndexError, we need to subtract 1 from factors.
        # `random_state` is not passed on this path.
        # NOTE(review): the matrix is densified with `.toarray()` and the GPU path
        # yields `factors - 1` components, unlike the CPU path - confirm both are intended.
        u, sigma, vt = cupy_svds(ui_csr.toarray(), k=self.factors - 1, tol=self.tol, maxiter=self.maxiter)
        # `.get()` is called on the CuPy results before they are stored on the model.
        u = u.get()
        self.item_factors = (cp.diag(sigma) @ vt).T.get()
    else:
        u, sigma, vt = svds(
            ui_csr, k=self.factors, tol=self.tol, maxiter=self.maxiter, random_state=self.random_state
        )
        self.item_factors = (np.diag(sigma) @ vt).T

    self.user_factors = u

def _get_users_factors(self, dataset: Dataset) -> Factors:
return Factors(self.user_factors)
Expand Down
24 changes: 23 additions & 1 deletion rectools/models/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic import TypeAdapter

from rectools.models.base import ModelBase, ModelClass, ModelConfig
from rectools.utils.misc import unflatten_dict
from rectools.utils.serialization import FileLike, read_bytes


Expand Down Expand Up @@ -46,7 +47,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
Parameters
----------
config : ModelConfig
config : dict or ModelConfig
Model config.
Returns
Expand All @@ -64,3 +65,24 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
raise ValueError("`cls` must be provided in the config to load the model")

return model_cls.from_config(config)


def model_from_params(params: dict, sep: str = ".") -> ModelBase:
    """
    Build a model from a flat dict of parameters.

    Same as `model_from_config` but accepts a flat dict: the keys are first
    turned into a nested config dict with `unflatten_dict`.

    Parameters
    ----------
    params : dict
        Model parameters as a flat dict with keys separated by `sep`.
    sep : str, default "."
        Separator for nested keys.

    Returns
    -------
    model
        Model instance.
    """
    return model_from_config(unflatten_dict(params, sep=sep))
31 changes: 31 additions & 0 deletions rectools/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,34 @@ def make_dict_flat(d: tp.Dict[str, tp.Any], sep: str = ".", parent_key: str = ""
else:
items.append((new_key, v))
return dict(items)


def unflatten_dict(d: tp.Dict[str, tp.Any], sep: str = ".") -> tp.Dict[str, tp.Any]:
    """
    Convert a flat dict with concatenated keys back into a nested dictionary.

    Parameters
    ----------
    d : dict
        Flattened dictionary.
    sep : str, default "."
        Separator used in flattened keys.

    Returns
    -------
    dict
        Nested dictionary.

    Raises
    ------
    ValueError
        If keys conflict: one key is used both as a value and as a prefix
        of another key, or two keys resolve to the same nested path.

    Examples
    --------
    >>> unflatten_dict({'a.b': 1, 'a.c': 2, 'd': 3})
    {'a': {'b': 1, 'c': 2}, 'd': 3}
    """
    result: tp.Dict[str, tp.Any] = {}
    for key, value in d.items():
        *parents, leaf = key.split(sep)
        current = result
        for part in parents:
            current = current.setdefault(part, {})
            if not isinstance(current, dict):
                # A shorter key already stored a plain value at this prefix.
                raise ValueError(f"Key '{key}' conflicts with another key: '{part}' is already a non-dict value")
        if leaf in current:
            # Assigning here would silently discard what a previous key stored.
            raise ValueError(f"Key '{key}' conflicts with another key that resolves to the same path")
        current[leaf] = value
    return result
10 changes: 10 additions & 0 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryFile
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -498,6 +499,15 @@ def test_from_config_dict_with_extra_keys(self) -> None:
):
self.model_class.from_config(config)

def test_from_params(self, mocker: MagicMock) -> None:
    """Check that `from_params` passes the unflattened params to `from_config`."""
    flat_params = {"x": 10, "verbose": 1, "sc.td": "P2DT3H"}
    expected_config = {"x": 10, "verbose": 1, "sc": {"td": "P2DT3H"}}
    from_config_spy = mocker.spy(self.model_class, "from_config")

    model = self.model_class.from_params(flat_params)

    # The flat "sc.td" key must reach `from_config` as a nested dict.
    from_config_spy.assert_called_once_with(expected_config)
    assert model.x == 10
    assert model.verbose == 1
    assert model.td == timedelta(days=2, hours=3)

def test_get_config_pydantic(self) -> None:
model = self.model_class(x=10, verbose=1)
config = model.get_config(mode="pydantic")
Expand Down
Loading

0 comments on commit c679553

Please sign in to comment.