diff --git a/cmaes/_cma.py b/cmaes/_cma.py index 5f43e1c..19e67ca 100644 --- a/cmaes/_cma.py +++ b/cmaes/_cma.py @@ -87,7 +87,7 @@ def __init__( ), f"Abs of all elements of mean vector must be less than {_MEAN_MAX}" n_dim = len(mean) - assert n_dim > 1, "The dimension of mean must be larger than 1" + assert n_dim > 0, "The dimension of mean must be positive" if population_size is None: population_size = 4 + math.floor(3 * math.log(n_dim)) # (eq. 48) diff --git a/tests/test_fuzzing.py b/tests/test_fuzzing.py index a567607..da753cf 100644 --- a/tests/test_fuzzing.py +++ b/tests/test_fuzzing.py @@ -10,7 +10,7 @@ class TestFuzzing(unittest.TestCase): data=st.data(), ) def test_cma_tell(self, data): - dim = data.draw(st.integers(min_value=2, max_value=100)) + dim = data.draw(st.integers(min_value=1, max_value=100)) mean = data.draw(npst.arrays(dtype=float, shape=dim)) sigma = data.draw(st.floats(min_value=1e-16)) n_iterations = data.draw(st.integers(min_value=1))