Skip to content

Commit bd5c2fe

Browse files
marcofavoritomarcofavoritobi
authored andcommitted
test: add tests for the DependencyNotInstalled exception for xgboost-sampler and gp-sampler
1 parent b9a8834 commit bd5c2fe

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

tests/test_samplers/test_gaussian_process.py

+15
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
import pytest
2121
from numpy.typing import NDArray
2222

23+
import black_it
2324
from black_it.samplers.gaussian_process import GaussianProcessSampler, _AcquisitionTypes
2425
from black_it.search_space import SearchSpace
2526
from tests.utils.base import no_gpy_installed, no_python311_for_gpy
27+
from tests.utils.extras_tester import generic_test_import_error
2628

2729
pytestmark = [no_python311_for_gpy, no_gpy_installed] # noqa
2830

@@ -137,3 +139,16 @@ def test_gaussian_process_sample_wrong_acquisition() -> None:
137139
f"got {wrong_acquisition}",
138140
):
139141
GaussianProcessSampler(4, acquisition=wrong_acquisition)
142+
143+
144+
def test_dependency_not_installed_error() -> None:
145+
"""Test the "DependencyNotInstalled" exception in case the dependencies of the component are not installed."""
146+
generic_test_import_error(
147+
module_obj=black_it.samplers.gaussian_process,
148+
import_error_global_variable_name="_GPY_IMPORT_ERROR",
149+
component_initializer=lambda: GaussianProcessSampler(4),
150+
expected_message_pattern=(
151+
r"Cannot import package 'GPy', required by component GaussianProcessSampler\. "
152+
r"To solve the issue, you can install the extra 'gp-sampler': pip install black-it\[gp-sampler\]"
153+
),
154+
)

tests/test_samplers/test_xgboost.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020

21+
import black_it
2122
from black_it.calibrator import Calibrator
2223
from black_it.loss_functions.msm import MethodOfMomentsLoss
2324
from black_it.samplers.halton import HaltonSampler
@@ -26,6 +27,7 @@
2627

2728
from ..fixtures.test_models import BH4 # type: ignore
2829
from ..utils.base import no_xgboost_installed
30+
from ..utils.extras_tester import generic_test_import_error
2931

3032
pytestmark = no_xgboost_installed # noqa
3133

@@ -136,3 +138,16 @@ def test_clip_losses() -> None:
136138
assert (
137139
y2 == np.array([0.0, MIN_FLOAT32 + EPS_FLOAT32, MAX_FLOAT32 - EPS_FLOAT32])
138140
).all()
141+
142+
143+
def test_dependency_not_installed_error() -> None:
144+
"""Test the "DependencyNotInstalled" exception in case the dependencies of the component are not installed."""
145+
generic_test_import_error(
146+
module_obj=black_it.samplers.xgboost,
147+
import_error_global_variable_name="_XGBOOST_IMPORT_ERROR",
148+
component_initializer=lambda: XGBoostSampler(4),
149+
expected_message_pattern=(
150+
r"Cannot import package 'xgboost', required by component XGBoostSampler. "
151+
r"To solve the issue, you can install the extra 'xgboost-sampler': pip install black-it\[xgboost-sampler\]"
152+
),
153+
)

tests/utils/extras_tester.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Black-box ABM Calibration Kit (Black-it)
2+
# Copyright (C) 2021-2023 Banca d'Italia
3+
#
4+
# This program is free software: you can redistribute it and/or modify
5+
# it under the terms of the GNU Affero General Public License as
6+
# published by the Free Software Foundation, either version 3 of the
7+
# License, or (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU Affero General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU Affero General Public License
15+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
"""Test utilities to test the correct working of the extras mechanism."""
18+
import contextlib
19+
from types import ModuleType
20+
from typing import Any, Callable, Generator
21+
22+
import pytest
23+
24+
from black_it._load_dependency import DependencyNotInstalled
25+
26+
27+
@contextlib.contextmanager
28+
def _change_variable_value(
29+
module_obj: ModuleType, variable_name: str, value: Any
30+
) -> Generator[None, None, None]:
31+
"""Change, temporarily, the value of a variable in a module."""
32+
old_value = getattr(module_obj, variable_name)
33+
setattr(module_obj, variable_name, value)
34+
yield
35+
setattr(module_obj, variable_name, old_value)
36+
37+
38+
def generic_test_import_error(
39+
module_obj: ModuleType,
40+
import_error_global_variable_name: str,
41+
component_initializer: Callable,
42+
expected_message_pattern: str,
43+
) -> None:
44+
"""
45+
Test that the correct exception is raised when a dependency is not installed.
46+
47+
This function is an utility testing function to test that the correct exception is raised when a dependency is not
48+
installed.
49+
50+
It assumes the module under testing, associated to some Black-it component (e.g. loss, sampler, etc.), has a global
51+
variable named import_error_global_variable_name, which is set to None if the dependency is installed, or to a
52+
DependencyNotInstalled exception if the dependency is not installed. This convention is the one used in the
53+
Black-it code.
54+
55+
For example, samplers.xgboost has a global variable named _XGBOOST_IMPORT_ERROR, which is set to None if the xgboost
56+
package is installed, or to an DependencyNotInstalled if the xgboost package is not installed. Then, during the
57+
initialization of XGBoostSampler, the value of the _XGBOOST_IMPORT_ERROR variable is checked, and if it is not
58+
None, an exception is raised.
59+
60+
This test function checks that the correct exception is raised when the import_error_global_variable_name variable
61+
is set to None and we try to initialize the Black-it component.
62+
63+
Args:
64+
module_obj: the module object to test
65+
import_error_global_variable_name: the name of the global variable in the module object
66+
component_initializer: the function to call to initialize the component under testing
67+
expected_message_pattern: the pattern of the expected exception message
68+
"""
69+
with _change_variable_value(
70+
module_obj, import_error_global_variable_name, ImportError("fake error")
71+
):
72+
with pytest.raises(DependencyNotInstalled, match=expected_message_pattern):
73+
component_initializer()

0 commit comments

Comments
 (0)