Skip to content

Commit 60f33d9

Browse files
committed
feat: add extras mechanism to finer-grained dependency selection
Inspired by: #36 (comment). This commit includes the extra-dependencies mechanism of setuptools to overcome limitations specific to certain dependencies (e.g. no support for some Python interpreter versions). The changes use the following conventions for extras names: - `[all]`: install all dependencies from all extras - `[X-sampler]`: install all dependencies to make X sampler to work - `[X-loss]`: install all dependencies to make X loss function to work. We do not have yet an example for the last item for the moment; but for "forward-compatibility" of the nomenclature, we leave the -sampler suffix. E.g. for GPy, we could have the extra called gp-sampler, that installs GPy on-demand, and not installed if not needed by the user. This commit also includes a mechanism to handle import errors for the non-installed dependencies for some component. Such mechanism provides a useful message to the user, e.g. it raises an exception with a useful error message pointing out to the missing extra in its local installation of black-it.
1 parent 5db2501 commit 60f33d9

11 files changed

+241
-10
lines changed

README.md

+18-3
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,34 @@ matter of days, with no need to reimplement all the plumbings from scratch.
3838

3939
This project requires Python v3.8 or later.
4040

41-
To install the latest version of the package from [PyPI](https://pypi.org/project/black-it/):
41+
To install the latest version of the package from [PyPI](https://pypi.org/project/black-it/), with all the extra dependencies (recommended):
4242
```
43-
pip install black-it
43+
pip install "black-it[all]"
4444
```
4545

4646
Or, directly from GitHub:
4747

4848
```
49-
pip install git+https://github.com/bancaditalia/black-it.git#egg=black-it
49+
pip install git+https://github.com/bancaditalia/black-it.git#egg="black-it[all]"
5050
```
5151

5252
If you'd like to contribute to the package, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) guide.
5353

54+
### Feature-specific Package Dependencies
55+
56+
We use the [optional dependencies mechanism of `setuptools`](https://setuptools.pypa.io/en/latest/userguide/dependency_management.html#optional-dependencies)
57+
(also called _extras_) to allow users to avoid dependencies for features they don't use.
58+
59+
For the basic features of the package, you can install the `black-it` package without extras, e.g. `pip install black-it`.
60+
However, for certain components, you will need to install some more extras using the syntax `pip install black-it[extra-1,extra-2,...]`.
61+
62+
For example, the [Gaussian Process Sampler](https://bancaditalia.github.io/black-it/samplers/#black_it.samplers.gaussian_process.GaussianProcessSampler)
63+
depends on the Python package [`GPy`](https://github.com/SheffieldML/GPy/).
64+
If the Gaussian Process sampler is not needed by your application, you can avoid its installation by just installing `black-it` as explained above.
65+
However, if you need the sampler, you must install `black-it` with the `gp-sampler` extra: `pip install black-it[gp-sampler]`.
66+
67+
The special extra `all` will install all the dependencies.
68+
5469
## Quick Example
5570

5671
The GitHub repo of Black-it contains a series ready-to-run calibration examples.

black_it/_load_dependency.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Black-box ABM Calibration Kit (Black-it)
2+
# Copyright (C) 2021-2023 Banca d'Italia
3+
#
4+
# This program is free software: you can redistribute it and/or modify
5+
# it under the terms of the GNU Affero General Public License as
6+
# published by the Free Software Foundation, either version 3 of the
7+
# License, or (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU Affero General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU Affero General Public License
15+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
"""
18+
Python module to handle extras dependencies loading and import errors.
19+
20+
This is a private module of the library. There should be no point in using it directly from client code.
21+
"""
22+
23+
import sys
24+
from typing import Optional
25+
26+
# known extras and their dependencies
27+
_GPY_PACKAGE_NAME = "GPy"
28+
_GP_SAMPLER_EXTRA_NAME = "gp-sampler"
29+
30+
_XGBOOST_PACKAGE_NAME = "xgboost"
31+
_XGBOOST_SAMPLER_EXTRA_NAME = "xgboost-sampler"
32+
33+
34+
class DependencyNotInstalled(Exception):
35+
"""Library exception for when a required dependency is not installed."""
36+
37+
def __init__(self, component_name: str, package_name: str, extra_name: str) -> None:
38+
"""Initialize the exception object."""
39+
message = (
40+
f"Cannot import package '{package_name}', required by component {component_name}. "
41+
f"To solve the issue, you can install the extra '{extra_name}': pip install black-it[{extra_name}]"
42+
)
43+
super().__init__(message)
44+
45+
46+
class GPyNotSupportedOnPy311Exception(Exception):
47+
"""Specific exception class for import error of GPy on Python 3.11."""
48+
49+
__ERROR_MSG = (
50+
f"The GaussianProcessSampler depends on '{_GPY_PACKAGE_NAME}', which is not supported on Python 3.11; "
51+
f"see https://github.com/bancaditalia/black-it/issues/36"
52+
)
53+
54+
def __init__(self) -> None:
55+
"""Initialize the exception object."""
56+
super().__init__(self.__ERROR_MSG)
57+
58+
59+
def _check_import_error_else_raise_exception(
60+
import_error: Optional[ImportError],
61+
component_name: str,
62+
package_name: str,
63+
black_it_extra_name: str,
64+
) -> None:
65+
"""
66+
Check an import error; raise the DependencyNotInstalled exception with a useful message.
67+
68+
Args:
69+
import_error: the ImportError object generated by the failed attempt. If None, then no error occurred.
70+
component_name: the component for which the dependency is needed
71+
package_name: the Python package name of the dependency
72+
black_it_extra_name: the name of the black-it extra to install to solve the issue.
73+
"""
74+
if import_error is None:
75+
# nothing to do.
76+
return
77+
78+
# an import error happened; we need to raise error to the caller
79+
raise DependencyNotInstalled(component_name, package_name, black_it_extra_name)
80+
81+
82+
def _check_gpy_import_error_else_raise_exception(
83+
import_error: Optional[ImportError],
84+
component_name: str,
85+
package_name: str,
86+
black_it_extra_name: str,
87+
) -> None:
88+
"""
89+
Check GPy import error and if an error occurred, raise erorr with a useful error message.
90+
91+
We need to handle two cases:
92+
93+
- the user is using Python 3.11: the GPy package cannot be installed there;
94+
see https://github.com/SheffieldML/GPy/issues/998
95+
- the user did not install the 'gp-sampler' extra.
96+
97+
Args:
98+
import_error: the ImportError object generated by the failed attempt. If None, then no error occurred.
99+
component_name: the component for which the dependency is needed
100+
package_name: the Python package name of the dependency
101+
black_it_extra_name: the name of the black-it extra to install to solve the issue.
102+
"""
103+
if import_error is None:
104+
# nothing to do.
105+
return
106+
107+
if sys.version_info == (3, 11):
108+
raise GPyNotSupportedOnPy311Exception()
109+
110+
_check_import_error_else_raise_exception(
111+
import_error, component_name, package_name, black_it_extra_name
112+
)

black_it/samplers/gaussian_process.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,26 @@
2020
from enum import Enum
2121
from typing import Optional, Tuple, cast
2222

23-
import GPy
2423
import numpy as np
25-
from GPy.models import GPRegression
2624
from numpy.typing import NDArray
2725
from scipy.special import erfc # pylint: disable=no-name-in-module
2826

27+
from black_it._load_dependency import (
28+
_GP_SAMPLER_EXTRA_NAME,
29+
_GPY_PACKAGE_NAME,
30+
_check_gpy_import_error_else_raise_exception,
31+
)
2932
from black_it.samplers.surrogate import MLSurrogateSampler
3033

34+
_GPY_IMPORT_ERROR: Optional[ImportError]
35+
try:
36+
import GPy
37+
from GPy.models import GPRegression
38+
except ImportError as e:
39+
_GPY_IMPORT_ERROR = e
40+
else:
41+
_GPY_IMPORT_ERROR = None
42+
3143

3244
class _AcquisitionTypes(Enum):
3345
"""Enumeration of allowed acquisition types."""
@@ -71,6 +83,8 @@ def __init__( # pylint: disable=too-many-arguments
7183
optimize_restarts: number of independent random trials of the optimization of the GP hyperparameters
7284
acquisition: type of acquisition function, it can be 'expected_improvement' of simply 'mean'
7385
"""
86+
self.__check_gpy_import_error()
87+
7488
self._validate_acquisition(acquisition)
7589

7690
super().__init__(
@@ -81,6 +95,13 @@ def __init__( # pylint: disable=too-many-arguments
8195
self.acquisition = acquisition
8296
self._gpmodel: Optional[GPRegression] = None
8397

98+
@classmethod
99+
def __check_gpy_import_error(cls) -> None:
100+
"""Check if an import error happened while attempting to import the 'GPy' package."""
101+
_check_gpy_import_error_else_raise_exception(
102+
_GPY_IMPORT_ERROR, cls.__name__, _GPY_PACKAGE_NAME, _GP_SAMPLER_EXTRA_NAME
103+
)
104+
84105
@staticmethod
85106
def _validate_acquisition(acquisition: str) -> None:
86107
"""
@@ -94,12 +115,12 @@ def _validate_acquisition(acquisition: str) -> None:
94115
"""
95116
try:
96117
_AcquisitionTypes(acquisition)
97-
except ValueError as e:
118+
except ValueError as exp:
98119
raise ValueError(
99120
"expected one of the following acquisition types: "
100121
f"[{' '.join(map(str, _AcquisitionTypes))}], "
101122
f"got {acquisition}"
102-
) from e
123+
) from exp
103124

104125
def fit(self, X: NDArray[np.float64], y: NDArray[np.float64]) -> None:
105126
"""Fit a gaussian process surrogate model."""

black_it/samplers/xgboost.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,27 @@
1919
from typing import Optional, cast
2020

2121
import numpy as np
22-
import xgboost as xgb
2322
from numpy.typing import NDArray
2423

24+
from black_it._load_dependency import (
25+
_XGBOOST_PACKAGE_NAME,
26+
_XGBOOST_SAMPLER_EXTRA_NAME,
27+
_check_import_error_else_raise_exception,
28+
)
2529
from black_it.samplers.surrogate import MLSurrogateSampler
2630

2731
MAX_FLOAT32 = np.finfo(np.float32).max
2832
MIN_FLOAT32 = np.finfo(np.float32).min
2933
EPS_FLOAT32 = np.finfo(np.float32).eps
3034

35+
_XGBOOST_IMPORT_ERROR: Optional[ImportError]
36+
try:
37+
import xgboost as xgb
38+
except ImportError as e:
39+
_XGBOOST_IMPORT_ERROR = e
40+
else:
41+
_XGBOOST_IMPORT_ERROR = None
42+
3143

3244
class XGBoostSampler(MLSurrogateSampler):
3345
"""This class implements xgboost sampling."""
@@ -64,6 +76,7 @@ def __init__( # pylint: disable=too-many-arguments
6476
References:
6577
Lamperti, Roventini, and Sani, "Agent-based model calibration using machine learning surrogates"
6678
"""
79+
self.__check_xgboost_import_error()
6780
super().__init__(
6881
batch_size, random_state, max_deduplication_passes, candidate_pool_size
6982
)
@@ -75,6 +88,16 @@ def __init__( # pylint: disable=too-many-arguments
7588
self._n_estimators = n_estimators
7689
self._xg_regressor: Optional[xgb.XGBRegressor] = None
7790

91+
@classmethod
92+
def __check_xgboost_import_error(cls) -> None:
93+
"""Check if an import error happened while attempting to import the 'xgboost' package."""
94+
_check_import_error_else_raise_exception(
95+
_XGBOOST_IMPORT_ERROR,
96+
cls.__name__,
97+
_XGBOOST_PACKAGE_NAME,
98+
_XGBOOST_SAMPLER_EXTRA_NAME,
99+
)
100+
78101
@property
79102
def colsample_bytree(self) -> float:
80103
"""Get the colsample_bytree parameter."""

poetry.lock

+6-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+8
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ tox = "^4.4.7"
9191
twine = "^4.0.0"
9292
vulture = "^2.3"
9393

94+
GPy = { version = "^1.10.0", optional = true }
95+
xgboost = { version = "^1.7.2", optional = true }
96+
97+
[tool.poetry.extras]
98+
gp-sampler = ["GPy"]
99+
xgboost-sampler = ["xgboost"]
100+
all = ["GPy", "xgboost"]
101+
94102
[build-system]
95103
requires = ["poetry-core>=1.0.0"]
96104
build-backend = "poetry.core.masonry.api"

tests/test_calibrator.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@
3838
from black_it.search_space import SearchSpace
3939

4040
from .fixtures.test_models import NormalMV # type: ignore
41+
from .utils.base import no_gpy_installed, no_python311_for_gpy, no_xgboost_installed
4142

4243

44+
@no_python311_for_gpy
45+
@no_gpy_installed
46+
@no_xgboost_installed
4347
class TestCalibrate: # pylint: disable=too-many-instance-attributes,attribute-defined-outside-init
4448
"""Test the Calibrator.calibrate method."""
4549

tests/test_samplers/test_gaussian_process.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
from black_it.samplers.gaussian_process import GaussianProcessSampler, _AcquisitionTypes
2424
from black_it.search_space import SearchSpace
25+
from tests.utils.base import no_gpy_installed, no_python311_for_gpy
26+
27+
pytestmark = [no_python311_for_gpy, no_gpy_installed] # noqa
2528

2629

2730
class TestGaussianProcess2D: # pylint: disable=attribute-defined-outside-init

tests/test_samplers/test_xgboost.py

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from black_it.search_space import SearchSpace
2424

2525
from ..fixtures.test_models import BH4 # type: ignore
26+
from ..utils.base import no_xgboost_installed
27+
28+
pytestmark = no_xgboost_installed # noqa
29+
2630

2731
expected_params = np.array([[0.24, 0.26], [0.26, 0.02], [0.08, 0.24], [0.15, 0.15]])
2832

tests/utils/base.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616

1717
"""Generic utility functions."""
1818
import dataclasses
19+
import importlib
1920
import shutil
2021
import signal
2122
import subprocess # nosec B404
23+
import sys
24+
import types
2225
from functools import wraps
23-
from typing import Callable, List, Type, Union
26+
from typing import Any, Callable, List, Optional, Type, Union
2427

2528
import pytest
29+
from _pytest.mark.structures import MarkDecorator # type: ignore
2630

31+
from black_it._load_dependency import _GPY_PACKAGE_NAME, _XGBOOST_PACKAGE_NAME
2732
from tests.conftest import DEFAULT_SUBPROCESS_TIMEOUT
2833

2934

@@ -170,3 +175,30 @@ def wrapper(*args, **kwargs): # type: ignore
170175
return wrapper
171176

172177
return decorator
178+
179+
180+
def try_import_else_none(module_name: str) -> Optional[types.ModuleType]:
181+
"""Try to import a module; if it fails, return None."""
182+
try:
183+
return importlib.import_module(module_name)
184+
except ImportError:
185+
return None
186+
187+
188+
def try_import_else_skip(package_name: str, **skipif_kwargs: Any) -> MarkDecorator:
189+
"""Try to import the package; else skip the test(s)."""
190+
return pytest.mark.skipif(
191+
try_import_else_none(package_name) is None,
192+
reason=f"Cannot run the test because the package '{package_name}' is not installed",
193+
**skipif_kwargs,
194+
)
195+
196+
197+
no_python311_for_gpy = pytest.mark.skipif(
198+
(3, 11) <= sys.version_info < (3, 12),
199+
reason="GPy not supported on Python 3.11, see: https://github.com/bancaditalia/black-it/issues/36",
200+
)
201+
202+
203+
no_gpy_installed = try_import_else_skip(_GPY_PACKAGE_NAME)
204+
no_xgboost_installed = try_import_else_skip(_XGBOOST_PACKAGE_NAME)

tox.ini

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ basepython = python3
1111
[testenv]
1212
setenv =
1313
PYTHONPATH = {toxinidir}
14+
extras =
15+
all
16+
gp-sampler
17+
xgboost-sampler
1418
deps =
1519
pytest>=7.1.2,<7.2.0
1620
pytest-cov>=3.0.0,<3.1.0

0 commit comments

Comments
 (0)