diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f32f1c..83ebbf97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased + +### Added +- `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229)) +- `from_params` method for models and `model_from_params` function ([#252](https://github.com/MobileTeleSystems/RecTools/pull/252)) + + ## [0.10.0] - 16.01.2025 ### Added diff --git a/poetry.lock b/poetry.lock index d73fdc78..93d95540 100644 --- a/poetry.lock +++ b/poetry.lock @@ -529,6 +529,36 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cupy-cuda12x" +version = "13.3.0" +description = "CuPy: NumPy & SciPy for GPU" +optional = true +python-versions = ">=3.9" +files = [ + {file = "cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:674488e990998042cc54d2486d3c37cae80a12ba3787636be5a10b9446dd6914"}, + {file = "cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:cf4a2a0864364715881b50012927e88bd7ec1e6f1de3987970870861ae5ed25e"}, + {file = "cupy_cuda12x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7c0dc8c49d271d1c03e49a5d6c8e42e8fee3114b10f269a5ecc387731d693eaa"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c0cc095b9a3835fd5db66c45ed3c58ecdc5a3bb14e53e1defbfd4a0ce5c8ecdb"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a0e3bead04e502ebde515f0343444ca3f4f7aed09cbc3a316a946cba97f2ea66"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:5f11df1149c7219858b27e4c8be92cb4eaf7364c94af6b78c40dffb98050a61f"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bbd0d916310391faf0d7dc9c58fff7a6dc996b67e5768199160bbceb5ebdda8c"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:e206bd8664f0763732b6012431f484ee535bffd77a5ae95e9bfe1c7c72396625"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:88ef1478f00ae252da0026e7f04f70c9bb6a2dc130ba5f1e5bc5e8069a928bf5"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3a52aa49ffcc940d034f2bb39728c90e9fa83c7a49e376404507956adb6d6ec4"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:3ef13f3cbc449d2a0f816594ab1fa0236e1f06ad1eaa81ad04c75e47cbeb87be"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:8f5433eec3e5cd8d39e8fcb82e0fdab7c22eba8e3304fcb0b42f2ea988fef0d6"}, +] + +[package.dependencies] +fastrlock = ">=0.5" +numpy = ">=1.22,<2.3" + +[package.extras] +all = ["Cython (>=0.29.22,<3)", "optuna (>=2.0)", "scipy (>=1.7,<1.14)"] +stylecheck = ["autopep8 (==1.5.5)", "flake8 (==3.8.4)", "mypy (==1.4.1)", "pbr (==5.5.1)", "pycodestyle (==2.6.0)", "types-setuptools (==57.4.14)"] +test = ["hypothesis (>=6.37.2,<6.55.0)", "mpmath", "packaging", "pytest (>=7.2)"] + [[package]] name = "decorator" version = "5.1.1" @@ -597,6 +627,90 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastrlock" +version = "0.8.2" +description = "Fast, re-entrant 
optimistic lock implemented in Cython" +optional = true +python-versions = "*" +files = [ + {file = "fastrlock-0.8.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:94e348c72a1fd1f8191f25ea056448e4f5a87b8fbf005b39d290dcb0581a48cd"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d5595903444c854b99c42122b87edfe8a37cd698a4eae32f4fd1d2a7b6c115d"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e4bbde174a0aff5f6eeba75cf8c4c5d2a316316bc21f03a0bddca0fc3659a6f3"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7a2ccaf88ac0db153e84305d1ef0aa138cea82c6a88309066f6eaa3bc98636cd"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:31a27a2edf482df72b91fe6c6438314d2c65290aa7becc55589d156c9b91f0da"}, + {file = "fastrlock-0.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e9904b5b37c3e5bb4a245c56bc4b7e497da57ffb8528f4fc39af9dcb168ee2e1"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:43a241655e83e4603a152192cf022d5ca348c2f4e56dfb02e5c9c4c1a32f9cdb"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9121a894d74e65557e47e777060a495ab85f4b903e80dd73a3c940ba042920d7"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:11bbbbc526363955aeddb9eec4cee2a0012322b7b2f15b54f44454fcf4fd398a"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:27786c62a400e282756ae1b090bcd7cfa35f28270cff65a9e7b27a5327a32561"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:08315bde19d0c2e6b06593d5a418be3dc8f9b1ee721afa96867b9853fceb45cf"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e8b49b5743ede51e0bcf6805741f39f5e0e0fd6a172ba460cb39e3097ba803bb"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b443e73a4dfc7b6e0800ea4c13567b9694358e86f53bb2612a51c9e727cac67b"}, + {file = "fastrlock-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:b3853ed4ce522598dc886160a7bab432a093051af85891fa2f5577c1dcac8ed6"}, + {file = "fastrlock-0.8.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:790fc19bccbd39426060047e53629f171a44745613bf360a045e9f9c8c4a2cea"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:dbdce852e6bb66e1b8c36679d482971d69d93acf1785657522e51b7de30c3356"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47713ffe6d4a627fbf078be9836a95ac106b4a0543e3841572c91e292a5d885"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:ea96503b918fceaf40443182742b8964d47b65c5ebdea532893cb9479620000c"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:c6bffa978793bea5e1b00e677062e53a62255439339591b70e209fa1552d5ee0"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:75c07726c8b1a52147fd7987d6baaa318c5dced1416c3f25593e40f56e10755b"}, + {file = 
"fastrlock-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:88f079335e9da631efa64486c8207564a7bcd0c00526bb9e842e9d5b7e50a6cc"}, + {file = "fastrlock-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4fb2e77ff04bc4beb71d63c8e064f052ce5a6ea1e001d528d4d7f4b37d736f2e"}, + {file = "fastrlock-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:b4c9083ea89ab236b06e9ef2263971db3b4b507195fc7d5eecab95828dcae325"}, + {file = "fastrlock-0.8.2-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:98195866d3a9949915935d40a88e4f1c166e82e378f622c88025f2938624a90a"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b22ea9bf5f9fad2b0077e944a7813f91593a4f61adf8faf734a70aed3f2b3a40"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dcc1bf0ac8a194313cf6e645e300a8a379674ceed8e0b1e910a2de3e3c28989e"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a3dcc876050b8f5cbc0ee84ef1e7f0c1dfe7c148f10098828bc4403683c33f10"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:685e656048b59d8dfde8c601f188ad53a4d719eb97080cafc8696cda6d75865e"}, + {file = "fastrlock-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:fb5363cf0fddd9b50525ddbf64a1e1b28ec4c6dfb28670a940cb1cf988a6786b"}, + {file = "fastrlock-0.8.2-cp35-cp35m-macosx_10_15_x86_64.whl", hash = "sha256:a74f5a92fa6e51c4f3c69b29c4662088b97be12f40652a21109605a175c81824"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ccf39ad5702e33e4d335b48ef9d56e21619b529b7f7471b5211419f380329b62"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:66f2662c640bb71a1016a031eea6eef9d25c2bcdf7ffd1d1ddc5a58f9a1ced04"}, + {file = "fastrlock-0.8.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:17734e2e5af4c07ddb0fb10bd484e062c22de3be6b67940b9cc6ec2f18fa61ba"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:ab91b0c36e95d42e1041a4907e3eefd06c482d53af3c7a77be7e214cc7cd4a63"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b32fdf874868326351a75b1e4c02f97e802147119ae44c52d3d9da193ec34f5b"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:2074548a335fcf7d19ebb18d9208da9e33b06f745754466a7e001d2b1c58dd19"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fb04442b6d1e2b36c774919c6bcbe3339c61b337261d4bd57e27932589095af"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:1fed2f4797ad68e9982038423018cf08bec5f4ce9fed63a94a790773ed6a795c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e380ec4e6d8b26e389713995a43cb7fe56baea2d25fe073d4998c4821a026211"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:25945f962c7bd808415cfde3da624d4399d4ea71ed8918538375f16bceb79e1c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2c1719ddc8218b01e82fb2e82e8451bd65076cb96d7bef4477194bbb4305a968"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = 
"sha256:5460c5ee6ced6d61ec8cd2324ebbe793a4960c4ffa2131ffff480e3b61c99ec5"}, + {file = "fastrlock-0.8.2-cp36-cp36m-win_amd64.whl", hash = "sha256:33145acbad8317584cd64588131c7e1e286beef6280c0009b4544c91fce171d2"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:59344c1d46b7dec97d3f22f1cc930fafe8980b3c5bc9c9765c56738a5f1559e4"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2a1c354f13f22b737621d914f3b4a8434ae69d3027a775e94b3e671756112f9"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:cf81e0278b645004388873e0a1f9e3bc4c9ab8c18e377b14ed1a544be4b18c9a"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1b15430b93d7eb3d56f6ff690d2ebecb79ed0e58248427717eba150a508d1cd7"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:067edb0a0805bf61e17a251d5046af59f6e9d2b8ad01222e0ef7a0b7937d5548"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb31fe390f03f7ae886dcc374f1099ec88526631a4cb891d399b68181f154ff0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:643e1e65b4f5b284427e61a894d876d10459820e93aa1e724dfb415117be24e0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5dfb78dd600a12f23fc0c3ec58f81336229fdc74501ecf378d1ce5b3f2f313ea"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b8ca0fe21458457077e4cb2d81e1ebdb146a00b3e9e2db6180a773f7ea905032"}, + {file = "fastrlock-0.8.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d918dfe473291e8bfd8e13223ea5cb9b317bd9f50c280923776c377f7c64b428"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:c393af77c659a38bffbca215c0bcc8629ba4299568308dd7e4ff65d62cabed39"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:73426f5eb2ecc10626c67cf86bd0af9e00d53e80e5c67d5ce8e18376d6abfa09"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:320fd55bafee3eb069cfb5d6491f811a912758387ef2193840e2663e80e16f48"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8c1c91a68926421f5ccbc82c85f83bd3ba593b121a46a1b9a554b3f0dd67a4bf"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ad1bc61c7f6b0e58106aaab034916b6cb041757f708b07fbcdd9d6e1ac629225"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:87f4e01b042c84e6090dbc4fbe3415ddd69f6bc0130382323f9d3f1b8dd71b46"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d34546ad2e4a480b94b6797bcc5a322b3c705c4c74c3e4e545c4a3841c1b2d59"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ebb32d776b61acd49f859a1d16b9e3d84e7b46d0d92aebd58acd54dc38e96664"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:30bdbe4662992348132d03996700e1cf910d141d629179b967b146a22942264e"}, + {file = "fastrlock-0.8.2-cp38-cp38-win_amd64.whl", hash = 
"sha256:07ed3c7b3867c05a3d6be4ced200c7767000f3431b9be6da66972822dd86e8be"}, + {file = "fastrlock-0.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:ddf5d247f686aec853ddcc9a1234bfcc6f57b0a0670d2ad82fc25d8ae7e6a15f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:7269bb3fc15587b0c191eecd95831d771a7d80f0c48929e560806b038ff3066c"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adcb9e77aa132cc6c9de2ffe7cf880a20aa8cdba21d367d1da1a412f57bddd5d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:a3b8b5d2935403f1b4b25ae324560e94b59593a38c0d2e7b6c9872126a9622ed"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2587cedbb36c7988e707d83f0f1175c1f882f362b5ebbee25d70218ea33d220d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9af691a9861027181d4de07ed74f0aee12a9650ac60d0a07f4320bff84b5d95f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:99dd6652bd6f730beadf74ef769d38c6bbd8ee6d1c15c8d138ea680b0594387f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4d63b6596368dab9e0cc66bf047e7182a56f33b34db141816a4f21f5bf958228"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ff75c90663d6e8996610d435e71487daa853871ad1770dd83dc0f2fc4997241e"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e27c3cd27fbd25e5223c5c992b300cd4ee8f0a75c6f222ce65838138d853712c"}, + {file = "fastrlock-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:dd961a32a7182c3891cdebca417fda67496d5d5de6ae636962254d22723bdf52"}, + {file = "fastrlock-0.8.2.tar.gz", hash = "sha256:644ec9215cf9c4df8028d8511379a15d9c1af3e16d80e47f1b6fdc6ba118356a"}, +] + [[package]] name = "filelock" version = "3.16.1" @@ -3638,7 +3752,8 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "rectools-lightfm", "torch", "torch"] +all = ["cupy-cuda12x", "ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "rectools-lightfm", "torch", "torch"] +cupy = ["cupy-cuda12x"] lightfm = ["rectools-lightfm"] nmslib = ["nmslib", "nmslib-metabrainz"] torch = ["pytorch-lightning", "torch", "torch"] @@ -3647,4 +3762,4 @@ visuals = ["ipywidgets", "nbformat", "plotly"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.13" -content-hash = "0be90e416bf1a931732bf02749481c4569077fa2ee63291abd6965d604c0d04a" +content-hash = "894765cd6220c87dd20ad10a0926b39298bfbc2d215ef0091dbf0849a468b349" diff --git a/pyproject.toml b/pyproject.toml index 4b13ac67..50a9b973 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ pytorch-lightning = {version = ">=1.6.0, <3.0.0", optional = true} ipywidgets = {version = ">=7.7,<8.2", optional = true} plotly = {version="^5.22.0", optional = true} nbformat = {version = ">=4.2.0", optional = true} +cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true} [tool.poetry.extras] @@ -91,11 +92,13 @@ lightfm = ["rectools-lightfm"] nmslib = ["nmslib", "nmslib-metabrainz"] torch = ["torch", "pytorch-lightning"] visuals = 
["ipywidgets", "plotly", "nbformat"] +cupy = ["cupy-cuda12x"] all = [ "rectools-lightfm", "nmslib", "nmslib-metabrainz", "torch", "pytorch-lightning", "ipywidgets", "plotly", "nbformat", + "cupy-cuda12x", ] diff --git a/rectools/models/__init__.py b/rectools/models/__init__.py index be69f0ce..ac40e522 100644 --- a/rectools/models/__init__.py +++ b/rectools/models/__init__.py @@ -28,6 +28,7 @@ `models.DSSMModel` `models.EASEModel` `models.ImplicitALSWrapperModel` +`models.ImplicitBPRWrapperModel` `models.ImplicitItemKNNWrapperModel` `models.LightFMWrapperModel` `models.PopularModel` @@ -48,7 +49,7 @@ from .popular_in_category import PopularInCategoryModel from .pure_svd import PureSVDModel from .random import RandomModel -from .serialization import load_model, model_from_config +from .serialization import load_model, model_from_config, model_from_params try: from .lightfm import LightFMWrapperModel @@ -76,4 +77,5 @@ "DSSMModel", "load_model", "model_from_config", + "model_from_params", ) diff --git a/rectools/models/base.py b/rectools/models/base.py index c7569918..604c3fb3 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -31,7 +31,7 @@ from rectools.exceptions import NotFittedError from rectools.types import ExternalIdsArray, InternalIdsArray from rectools.utils.config import BaseConfig -from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat +from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes T = tp.TypeVar("T", bound="ModelBase") @@ -210,6 +210,26 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self: return cls._from_config(config_obj) + @classmethod + def from_params(cls, params: tp.Dict[str, tp.Any], sep: str = ".") -> tpe.Self: + """ + Create model from parameters. + Same as `from_config` but accepts flat dict. + + Parameters + ---------- + params : dict + Model parameters as a flat dict with keys separated by `sep`. + sep : str, default "." + Separator for nested keys. + + Returns + ------- + Model instance. + """ + config_dict = unflatten_dict(params, sep=sep) + return cls.from_config(config_dict) + @classmethod def _from_config(cls, config: ModelConfig_T) -> tpe.Self: raise NotImplementedError() diff --git a/rectools/models/pure_svd.py b/rectools/models/pure_svd.py index 3d5d5de5..a0ba2153 100644 --- a/rectools/models/pure_svd.py +++ b/rectools/models/pure_svd.py @@ -15,6 +15,7 @@ """SVD Model.""" import typing as tp +import warnings import numpy as np import typing_extensions as tpe @@ -26,6 +27,15 @@ from rectools.models.rank import Distance from rectools.models.vector import Factors, VectorModel +try: + import cupy as cp + from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix + from cupyx.scipy.sparse.linalg import svds as cupy_svds +except ImportError: # pragma: no cover + cupy_svds = None + cp_csr_matrix = None + cp = None + class PureSVDModelConfig(ModelConfig): """Config for `PureSVD` model.""" @@ -34,6 +44,7 @@ class PureSVDModelConfig(ModelConfig): tol: float = 0 maxiter: tp.Optional[int] = None random_state: tp.Optional[int] = None + use_gpu: tp.Optional[bool] = False recommend_n_threads: int = 0 recommend_use_gpu_ranking: bool = True @@ -53,7 +64,9 @@ class PureSVDModel(VectorModel[PureSVDModelConfig]): maxiter : int, optional, default ``None`` Maximum number of iterations. 
random_state : int, optional, default ``None``
-        Pseudorandom number generator state used to generate resamples.
+        Pseudorandom number generator state used to generate resamples. Ignored if ``use_gpu`` is ``True``.
+    use_gpu : bool, default ``False``
+        If ``True``, `cupyx.scipy.sparse.linalg.svds()` is used instead of `scipy.sparse.linalg.svds()`. CuPy is required.
    verbose : int, default ``0``
        Degree of verbose output. If ``0``, no output will be provided.
    recommend_n_threads: int, default 0
@@ -83,6 +96,7 @@ def __init__(
        tol: float = 0,
        maxiter: tp.Optional[int] = None,
        random_state: tp.Optional[int] = None,
+        use_gpu: tp.Optional[bool] = False,
        verbose: int = 0,
        recommend_n_threads: int = 0,
        recommend_use_gpu_ranking: bool = True,
@@ -93,6 +107,16 @@ def __init__(
        self.tol = tol
        self.maxiter = maxiter
        self.random_state = random_state
+        self._use_gpu = use_gpu  # for making a config
+        if use_gpu:  # pragma: no cover
+            if not cp:
+                warnings.warn("Forced to use CPU. CuPy is not available.")
+                use_gpu = False
+            elif not cp.cuda.is_available():
+                warnings.warn("Forced to use CPU. GPU is not available.")
+                use_gpu = False
+
+        self.use_gpu = use_gpu
        self.recommend_n_threads = recommend_n_threads
        self.recommend_use_gpu_ranking = recommend_use_gpu_ranking
@@ -106,6 +130,7 @@ def _get_config(self) -> PureSVDModelConfig:
            tol=self.tol,
            maxiter=self.maxiter,
            random_state=self.random_state,
+            use_gpu=self._use_gpu,
            verbose=self.verbose,
            recommend_n_threads=self.recommend_n_threads,
            recommend_use_gpu_ranking=self.recommend_use_gpu_ranking,
@@ -118,6 +143,7 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self:
            tol=config.tol,
            maxiter=config.maxiter,
            random_state=config.random_state,
+            use_gpu=config.use_gpu,
            verbose=config.verbose,
            recommend_n_threads=config.recommend_n_threads,
            recommend_use_gpu_ranking=config.recommend_use_gpu_ranking,
@@ -126,10 +152,19 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self:
    def _fit(self, dataset: Dataset) -> None:  # type: ignore
        ui_csr = dataset.get_user_item_matrix(include_weights=True)
-        u, sigma, vt = svds(ui_csr, k=self.factors, tol=self.tol, maxiter=self.maxiter, random_state=self.random_state)
+        if self.use_gpu:  # pragma: no cover
+            ui_csr = cp_csr_matrix(ui_csr)
+            # To prevent IndexError, we need to subtract 1 from factors
+            u, sigma, vt = cupy_svds(ui_csr.toarray(), k=self.factors - 1, tol=self.tol, maxiter=self.maxiter)
+            u = u.get()
+            self.item_factors = (cp.diag(sigma) @ vt).T.get()
+        else:
+            u, sigma, vt = svds(
+                ui_csr, k=self.factors, tol=self.tol, maxiter=self.maxiter, random_state=self.random_state
+            )
+            self.item_factors = (np.diag(sigma) @ vt).T
        self.user_factors = u
-        self.item_factors = (np.diag(sigma) @ vt).T

    def _get_users_factors(self, dataset: Dataset) -> Factors:
        return Factors(self.user_factors)
diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py
index 91844187..f4ce3c58 100644
--- a/rectools/models/serialization.py
+++ b/rectools/models/serialization.py
@@ -18,6 +19,7 @@
 from pydantic import TypeAdapter
 from rectools.models.base import ModelBase, ModelClass, ModelConfig
+from rectools.utils.misc import unflatten_dict
 from rectools.utils.serialization import FileLike, read_bytes
@@ -46,7 +47,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase:
    Parameters
    ----------
-    config : ModelConfig
+    config : dict or ModelConfig
        Model config.
Returns @@ -64,3 +65,24 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: raise ValueError("`cls` must be provided in the config to load the model") return model_cls.from_config(config) + + +def model_from_params(params: dict, sep: str = ".") -> ModelBase: + """ + Create model from dict of parameters. + Same as `from_config` but accepts flat dict. + + Parameters + ---------- + params : dict + Model parameters as a flat dict with keys separated by `sep`. + sep : str, default "." + Separator for nested keys. + + Returns + ------- + model + Model instance. + """ + config_dict = unflatten_dict(params, sep=sep) + return model_from_config(config_dict) diff --git a/rectools/utils/misc.py b/rectools/utils/misc.py index 3e6ba433..1d81da14 100644 --- a/rectools/utils/misc.py +++ b/rectools/utils/misc.py @@ -228,3 +228,34 @@ def make_dict_flat(d: tp.Dict[str, tp.Any], sep: str = ".", parent_key: str = "" else: items.append((new_key, v)) return dict(items) + + +def unflatten_dict(d: tp.Dict[str, tp.Any], sep: str = ".") -> tp.Dict[str, tp.Any]: + """ + Convert a flat dict with concatenated keys back into a nested dictionary. + + Parameters + ---------- + d : dict + Flattened dictionary. + sep : str, default "." + Separator used in flattened keys. + + Returns + ------- + dict + Nested dictionary. + + Examples + -------- + >>> unflatten_dict({'a.b': 1, 'a.c': 2, 'd': 3}) + {'a': {'b': 1, 'c': 2}, 'd': 3} + """ + result: tp.Dict[str, tp.Any] = {} + for key, value in d.items(): + parts = key.split(sep) + current = result + for part in parts[:-1]: + current = current.setdefault(part, {}) + current[parts[-1]] = value + return result diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 5ab7d68e..d281017f 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -19,6 +19,7 @@ from datetime import timedelta from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryFile +from unittest.mock import MagicMock import numpy as np import pandas as pd @@ -498,6 +499,15 @@ def test_from_config_dict_with_extra_keys(self) -> None: ): self.model_class.from_config(config) + def test_from_params(self, mocker: MagicMock) -> None: + params = {"x": 10, "verbose": 1, "sc.td": "P2DT3H"} + spy = mocker.spy(self.model_class, "from_config") + model = self.model_class.from_params(params) + spy.assert_called_once_with({"x": 10, "verbose": 1, "sc": {"td": "P2DT3H"}}) + assert model.x == 10 + assert model.td == timedelta(days=2, hours=3) + assert model.verbose == 1 + def test_get_config_pydantic(self) -> None: model = self.model_class(x=10, verbose=1) config = model.get_config(mode="pydantic") diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 71a7ed67..408023da 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd import pytest +from pytest_mock import MockerFixture from rectools import Columns from rectools.dataset import Dataset @@ -32,6 +33,17 @@ assert_second_fit_refits_model, ) +try: + import cupy as cp # pylint: disable=import-error, unused-import +except ImportError: # pragma: no cover + cp = None + +try: + HAS_CUDA = cp.is_available() if cp else False +except Exception: # pragma: no cover # pylint: disable=broad-except + # If CUDA isn't installed cupy raises CUDARuntimeError: + HAS_CUDA = False + class TestPureSVDModel: @@ -72,6 +84,7 @@ def test_basic( expected: pd.DataFrame, use_gpu_ranking: bool, ) -> None: + model = 
PureSVDModel(factors=2, recommend_use_gpu_ranking=use_gpu_ranking).fit(dataset) actual = model.recommend( users=np.array([10, 20]), @@ -85,6 +98,52 @@ def test_basic( actual, ) + # SciPy's svds and cupy's svds results can be different and use_gpu fallback causes errors + @pytest.mark.skipif(cp is None or not HAS_CUDA, reason="CUDA is not available") + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [10, 10, 20, 20], + Columns.Item: [15, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [10, 10, 20, 20], + Columns.Item: [11, 12, 11, 12], + Columns.Rank: [1, 2, 1, 2], + } + ), + ), + ), + ) + def test_basic_gpu( + self, + dataset: Dataset, + filter_viewed: bool, + expected: pd.DataFrame, + ) -> None: + model = PureSVDModel(factors=2, use_gpu=True, recommend_use_gpu_ranking=True).fit(dataset) + actual = model.recommend( + users=np.array([10, 20]), + dataset=dataset, + k=2, + filter_viewed=filter_viewed, + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + @pytest.mark.parametrize( "filter_viewed,expected", ( @@ -279,12 +338,16 @@ def test_dumps_loads(self, dataset: Dataset) -> None: class TestPureSVDModelConfiguration: - def test_from_config(self) -> None: + @pytest.mark.parametrize("use_gpu", (False, True)) + def test_from_config(self, mocker: MockerFixture, use_gpu: bool) -> None: + mocker.patch("rectools.models.pure_svd.cp", return_value=True) + mocker.patch("rectools.models.pure_svd.cp.cuda.is_available", return_value=True) config = { "factors": 100, "tol": 0, "maxiter": 100, "random_state": 32, + "use_gpu": use_gpu, "verbose": 0, } model = PureSVDModel.from_config(config) @@ -296,12 +359,18 @@ def test_from_config(self) -> None: @pytest.mark.parametrize("random_state", (None, 42)) @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> None: + @pytest.mark.parametrize("use_gpu", (False, True)) + def test_get_config( + self, mocker: MockerFixture, random_state: tp.Optional[int], simple_types: bool, use_gpu: bool + ) -> None: + mocker.patch("rectools.models.pure_svd.cp.cuda.is_available", return_value=True) + mocker.patch("rectools.models.pure_svd.cp", return_value=True) model = PureSVDModel( factors=100, - tol=1, + tol=1.0, maxiter=100, random_state=random_state, + use_gpu=use_gpu, verbose=1, recommend_n_threads=2, recommend_use_gpu_ranking=False, @@ -310,9 +379,10 @@ def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> expected = { "cls": "PureSVDModel" if simple_types else PureSVDModel, "factors": 100, - "tol": 1, + "tol": 1.0, "maxiter": 100, "random_state": random_state, + "use_gpu": use_gpu, "verbose": 1, "recommend_n_threads": 2, "recommend_use_gpu_ranking": False, diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index ce95af17..f3340707 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -14,6 +14,7 @@ import typing as tp from tempfile import NamedTemporaryFile +from unittest.mock import MagicMock import pytest from implicit.als import AlternatingLeastSquares @@ -26,7 +27,6 @@ except ImportError: LightFM = object # it's ok in case we're skipping the tests - from rectools.metrics import NDCG from rectools.models 
import (
    DSSMModel,
@@ -39,10 +39,13 @@
    PopularModel,
    load_model,
    model_from_config,
+    model_from_params,
+    serialization,
 )
 from rectools.models.base import ModelBase, ModelConfig
 from rectools.models.nn.transformer_base import TransformerModelBase
 from rectools.models.vector import VectorModel
+from rectools.utils.config import BaseConfig
 from .utils import get_successors
@@ -79,20 +82,26 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None:
     assert not loaded_model.is_fitted
+class CustomModelSubConfig(BaseConfig):
+    x: int = 10
+
+
 class CustomModelConfig(ModelConfig):
     some_param: int = 1
+    sc: CustomModelSubConfig = CustomModelSubConfig()

 class CustomModel(ModelBase[CustomModelConfig]):
     config_class = CustomModelConfig

-    def __init__(self, some_param: int = 1, verbose: int = 0):
+    def __init__(self, some_param: int = 1, x: int = 10, verbose: int = 0):
         super().__init__(verbose=verbose)
         self.some_param = some_param
+        self.x = x

     @classmethod
     def _from_config(cls, config: CustomModelConfig) -> "CustomModel":
-        return cls(some_param=config.some_param, verbose=config.verbose)
+        return cls(some_param=config.some_param, x=config.sc.x, verbose=config.verbose)

 class TestModelFromConfig:
@@ -121,6 +130,7 @@ def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig])
         model = model_from_config(config)
         assert isinstance(model, CustomModel)
         assert model.some_param == 2
+        assert model.x == 10

     @pytest.mark.parametrize("simple_types", (False, True))
     def test_fails_on_missing_cls(self, simple_types: bool) -> None:
@@ -179,3 +189,15 @@ def test_fails_on_model_cls_without_from_config_support(self, model_cls: tp.Any)
         config = {"cls": model_cls}
         with pytest.raises(NotImplementedError, match="`from_config` method is not implemented for `DSSMModel` model"):
             model_from_config(config)
+
+
+class TestModelFromParams:
+    def test_uses_from_config(self, mocker: MagicMock) -> None:
+        params = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc.x": 20}
+        spy = mocker.spy(serialization, "model_from_config")
+        model = model_from_params(params)
+        expected_config = {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2, "sc": {"x": 20}}
+        spy.assert_called_once_with(expected_config)
+        assert isinstance(model, CustomModel)
+        assert model.some_param == 2
+        assert model.x == 20
diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py
new file mode 100644
index 00000000..eb493f7e
--- /dev/null
+++ b/tests/utils/test_misc.py
@@ -0,0 +1,42 @@
+from rectools.utils.misc import unflatten_dict
+
+
+class TestUnflattenDict:
+    def test_empty(self) -> None:
+        assert unflatten_dict({}) == {}
+
+    def test_complex(self) -> None:
+        flattened = {
+            "a.b": 1,
+            "a.c": 2,
+            "d": 3,
+            "a.e.f": [10, 20],
+        }
+        expected = {
+            "a": {"b": 1, "c": 2, "e": {"f": [10, 20]}},
+            "d": 3,
+        }
+        assert unflatten_dict(flattened) == expected
+
+    def test_simple(self) -> None:
+        flattened = {
+            "a": 1,
+            "b": 2,
+        }
+        expected = {
+            "a": 1,
+            "b": 2,
+        }
+        assert unflatten_dict(flattened) == expected
+
+    def test_non_default_sep(self) -> None:
+        flattened = {
+            "a_b": 1,
+            "a_c": 2,
+            "d": 3,
+        }
+        expected = {
+            "a": {"b": 1, "c": 2},
+            "d": 3,
+        }
+        assert unflatten_dict(flattened, sep="_") == expected
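
For readers of this PR, a minimal usage sketch of the additions above: `model_from_params` / `ModelBase.from_params` and the `use_gpu` option of `PureSVDModel`. Dotted keys are split exactly as `unflatten_dict` does in the tests; the `"cls"` import-path string and the parameter values below are illustrative assumptions, not part of the diff.

```python
from rectools.models import PureSVDModel, model_from_params

# Flat dict: dotted keys such as "sc.x" are unflattened into nested config
# sections before being passed on to `model_from_config`.
params = {
    "cls": "rectools.models.PureSVDModel",  # assumed import-path form for `cls`
    "factors": 64,
    "use_gpu": True,  # falls back to CPU with a warning if CuPy or CUDA is missing
}
model = model_from_params(params)

# Equivalent call through the new classmethod on a concrete model class:
model = PureSVDModel.from_params({"factors": 64, "use_gpu": True})
```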