Skip to content

Commit

Permalink
Feature/model from params (#252)
Browse files Browse the repository at this point in the history
- Added `from_params` method for models and `model_from_params` function
  • Loading branch information
feldlime authored Feb 4, 2025
1 parent fa6c201 commit 5664a3c
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 7 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased
- Add `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))

### Added
- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229))
- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252))


## [0.10.0] - 16.01.2025

Expand Down
4 changes: 3 additions & 1 deletion rectools/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
`models.DSSMModel`
`models.EASEModel`
`models.ImplicitALSWrapperModel`
`models.ImplicitBPRWrapperModel`
`models.ImplicitItemKNNWrapperModel`
`models.LightFMWrapperModel`
`models.PopularModel`
Expand All @@ -44,7 +45,7 @@
from .popular_in_category import PopularInCategoryModel
from .pure_svd import PureSVDModel
from .random import RandomModel
from .serialization import load_model, model_from_config
from .serialization import load_model, model_from_config, model_from_params

try:
from .lightfm import LightFMWrapperModel
Expand All @@ -70,4 +71,5 @@
"DSSMModel",
"load_model",
"model_from_config",
"model_from_params",
)
22 changes: 21 additions & 1 deletion rectools/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from rectools.exceptions import NotFittedError
from rectools.types import ExternalIdsArray, InternalIdsArray
from rectools.utils.config import BaseConfig
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes

T = tp.TypeVar("T", bound="ModelBase")
Expand Down Expand Up @@ -210,6 +210,26 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:

return cls._from_config(config_obj)

@classmethod
def from_params(cls, params: tp.Dict[str, tp.Any], sep: str = ".") -> tpe.Self:
    """
    Build a model from a flat dictionary of parameters.

    The flat keys are expanded into a nested dict with `unflatten_dict`
    and the result is handed to `from_config`.

    Parameters
    ----------
    params : dict
        Model parameters as a flat dict; nested keys are joined with `sep`.
    sep : str, default "."
        Separator that splits a flat key into its nested parts.

    Returns
    -------
    Model instance.
    """
    return cls.from_config(unflatten_dict(params, sep=sep))

@classmethod
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
raise NotImplementedError()
Expand Down
24 changes: 23 additions & 1 deletion rectools/models/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic import TypeAdapter

from rectools.models.base import ModelBase, ModelClass, ModelConfig
from rectools.utils.misc import unflatten_dict
from rectools.utils.serialization import FileLike, read_bytes


Expand Down Expand Up @@ -46,7 +47,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
Parameters
----------
config : ModelConfig
config : dict or ModelConfig
Model config.
Returns
Expand All @@ -64,3 +65,24 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
raise ValueError("`cls` must be provided in the config to load the model")

return model_cls.from_config(config)


def model_from_params(params: dict, sep: str = ".") -> ModelBase:
    """
    Build a model from a flat dictionary of parameters.

    The flat keys are expanded into a nested dict with `unflatten_dict`
    and the result is handed to `model_from_config`.

    Parameters
    ----------
    params : dict
        Model parameters as a flat dict; nested keys are joined with `sep`.
    sep : str, default "."
        Separator that splits a flat key into its nested parts.

    Returns
    -------
    model
        Model instance.
    """
    return model_from_config(unflatten_dict(params, sep=sep))
31 changes: 31 additions & 0 deletions rectools/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,34 @@ def make_dict_flat(d: tp.Dict[str, tp.Any], sep: str = ".", parent_key: str = ""
else:
items.append((new_key, v))
return dict(items)


def unflatten_dict(d: tp.Dict[str, tp.Any], sep: str = ".") -> tp.Dict[str, tp.Any]:
    """
    Convert a flat dict with concatenated keys back into a nested dictionary.

    Each key is split on `sep`; every part except the last names a nested
    dict (created on demand) and the last part names the entry that receives
    the value. Values are inserted as-is, without copying.

    Parameters
    ----------
    d : dict
        Flattened dictionary.
    sep : str, default "."
        Separator used in flattened keys.

    Returns
    -------
    dict
        Nested dictionary.

    Raises
    ------
    ValueError
        If keys conflict with each other, i.e. one key treats a path as
        a nested section while another key assigns a value to the same path
        (for example ``{'a': 1, 'a.b': 2}``).

    Examples
    --------
    >>> unflatten_dict({'a.b': 1, 'a.c': 2, 'd': 3})
    {'a': {'b': 1, 'c': 2}, 'd': 3}
    """
    result: tp.Dict[str, tp.Any] = {}
    for key, value in d.items():
        *parents, leaf = key.split(sep)
        current = result
        for part in parents:
            current = current.setdefault(part, {})
            # A non-dict here means an earlier key already stored a plain value at
            # this path; descending into it would fail with an obscure TypeError.
            if not isinstance(current, dict):
                raise ValueError(
                    f"Key `{key}` conflicts with another key: `{part}` is already set to a non-dict value"
                )
        # Overwriting an existing entry would silently discard data collected
        # from other keys, so treat it as a conflict as well.
        if leaf in current:
            raise ValueError(f"Key `{key}` conflicts with another key: `{leaf}` is already set")
        current[leaf] = value
    return result
10 changes: 10 additions & 0 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryFile
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -498,6 +499,15 @@ def test_from_config_dict_with_extra_keys(self) -> None:
):
self.model_class.from_config(config)

def test_from_params(self, mocker: MagicMock) -> None:
    """`from_params` expands flat keys and delegates to `from_config`."""
    flat_params = {"x": 10, "verbose": 1, "sc.td": "P2DT3H"}
    nested_config = {"x": 10, "verbose": 1, "sc": {"td": "P2DT3H"}}
    from_config_spy = mocker.spy(self.model_class, "from_config")

    model = self.model_class.from_params(flat_params)

    # The flat "sc.td" key must reach `from_config` as a nested "sc" section.
    from_config_spy.assert_called_once_with(nested_config)
    assert (model.x, model.verbose) == (10, 1)
    assert model.td == timedelta(days=2, hours=3)

def test_get_config_pydantic(self) -> None:
model = self.model_class(x=10, verbose=1)
config = model.get_config(mode="pydantic")
Expand Down
28 changes: 25 additions & 3 deletions tests/models/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import typing as tp
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock

import pytest
from implicit.als import AlternatingLeastSquares
Expand All @@ -26,7 +27,6 @@
except ImportError:
LightFM = object # it's ok in case we're skipping the tests


from rectools.metrics import NDCG
from rectools.models import (
DSSMModel,
Expand All @@ -39,9 +39,12 @@
PopularModel,
load_model,
model_from_config,
model_from_params,
serialization,
)
from rectools.models.base import ModelBase, ModelConfig
from rectools.models.vector import VectorModel
from rectools.utils.config import BaseConfig

from .utils import get_successors

Expand Down Expand Up @@ -77,20 +80,26 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None:
assert isinstance(loaded_model, model_cls)


# Nested config section for the test model; holds a single integer `x` (default 10).
class CustomModelSubConfig(BaseConfig):
x: int = 10


# Config for the test model: one top-level parameter plus a nested `sc` section.
class CustomModelConfig(ModelConfig):
some_param: int = 1
# Nested section; defaults to a `CustomModelSubConfig` built with its own defaults.
sc: CustomModelSubConfig = CustomModelSubConfig()


# Small `ModelBase` subclass used by the serialization tests; configured via `CustomModelConfig`.
class CustomModel(ModelBase[CustomModelConfig]):
config_class = CustomModelConfig

# NOTE(review): two consecutive `def __init__` lines — this looks like a rendered diff
# showing the signature before and after the change (the second adds `x`); confirm
# against the actual file.
def __init__(self, some_param: int = 1, verbose: int = 0):
def __init__(self, some_param: int = 1, x: int = 10, verbose: int = 0):
super().__init__(verbose=verbose)
self.some_param = some_param
self.x = x

@classmethod
def _from_config(cls, config: CustomModelConfig) -> "CustomModel":
# NOTE(review): likewise the two `return` lines appear to be the old and new versions;
# the second one additionally reads `x` from the nested `sc` section of the config.
return cls(some_param=config.some_param, verbose=config.verbose)
return cls(some_param=config.some_param, x=config.sc.x, verbose=config.verbose)


class TestModelFromConfig:
Expand Down Expand Up @@ -119,6 +128,7 @@ def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig])
model = model_from_config(config)
assert isinstance(model, CustomModel)
assert model.some_param == 2
assert model.x == 10

@pytest.mark.parametrize("simple_types", (False, True))
def test_fails_on_missing_cls(self, simple_types: bool) -> None:
Expand Down Expand Up @@ -177,3 +187,15 @@ def test_fails_on_model_cls_without_from_config_support(self, model_cls: tp.Any)
config = {"cls": model_cls}
with pytest.raises(NotImplementedError, match="`from_config` method is not implemented for `DSSMModel` model"):
model_from_config(config)


class TestModelFromParams:
    """Tests for `model_from_params`."""

    def test_uses_from_config(self, mocker: MagicMock) -> None:
        """Flat params are expanded and forwarded to `model_from_config`."""
        model_path = "tests.models.test_serialization.CustomModel"
        from_config_spy = mocker.spy(serialization, "model_from_config")

        model = model_from_params({"cls": model_path, "some_param": 2, "sc.x": 20})

        # The flat "sc.x" key must arrive as a nested "sc" section.
        from_config_spy.assert_called_once_with({"cls": model_path, "some_param": 2, "sc": {"x": 20}})
        assert isinstance(model, CustomModel)
        assert (model.some_param, model.x) == (2, 20)
42 changes: 42 additions & 0 deletions tests/utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from rectools.utils.misc import unflatten_dict


class TestUnflattenDict:
    """Tests for `unflatten_dict`."""

    def test_empty(self) -> None:
        """An empty mapping stays empty."""
        assert unflatten_dict({}) == {}

    def test_complex(self) -> None:
        """Keys of different depth that share a prefix end up in one nested dict."""
        flat = {"a.b": 1, "a.c": 2, "d": 3, "a.e.f": [10, 20]}
        expected = {
            "a": {"b": 1, "c": 2, "e": {"f": [10, 20]}},
            "d": 3,
        }
        assert unflatten_dict(flat) == expected

    def test_simple(self) -> None:
        """Keys without the separator are returned unchanged."""
        flat = {"a": 1, "b": 2}
        expected = {"a": 1, "b": 2}
        assert unflatten_dict(flat) == expected

    def test_non_default_sep(self) -> None:
        """A custom separator is used to split the keys."""
        flat = {"a_b": 1, "a_c": 2, "d": 3}
        expected = {"a": {"b": 1, "c": 2}, "d": 3}
        assert unflatten_dict(flat, sep="_") == expected

0 comments on commit 5664a3c

Please sign in to comment.