From 0a48fc605be6ca8231604447749624829fcb749d Mon Sep 17 00:00:00 2001 From: D-TheProgrammer <151149998+D-TheProgrammer@users.noreply.github.com> Date: Thu, 30 May 2024 19:13:28 +0200 Subject: [PATCH 01/12] New Fcatorial2 respecting linters --- iodata/overlap.py | 63 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index b2671a97..351ed61c 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -25,32 +25,56 @@ import scipy.special from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS -from .basis import MolecularBasis, Shell, convert_conventions, iter_cart_alphabet +from .basis import MolecularBasis, convert_conventions, iter_cart_alphabet from .overlap_cartpure import tfs __all__ = ["OVERLAP_CONVENTIONS", "compute_overlap", "gob_cart_normalization"] def factorial2(n, exact=False): - """Wrap scipy.special.factorial2 to return 1.0 when the input is -1. + """Wrap scipy.special.factorial2 to return 1.0 when the input is -1 and handle float arrays. This is a temporary workaround while we wait for Scipy's update. To learn more, see https://github.com/scipy/scipy/issues/18409. Parameters ---------- - n : int or np.ndarray + n : int, float, or np.ndarray Values to calculate n!! for. If n={0, -1}, the return value is 1. For n < -1, the return value is 0. """ - # Scipy 1.11.x returns an integer when n is an integer, but 1.10.x returns an array, - # so np.array(n) is passed to make sure the output is always an array. - out = scipy.special.factorial2(np.array(n), exact=exact) - out[out <= 0] = 1.0 - out[out <= -2] = 0.0 - return out - - + # Handle integer inputs + if isinstance(n, (int, np.integer)): + if n in {-1, 0}: + return 1.0 + if n < -1: + return 0.0 + return scipy.special.factorial2(n, exact=exact) + + # Handle float inputs + if isinstance(n, float): + if n in {-1.0, 0.0}: + return 1.0 + if n < -1.0: + return 0.0 + return scipy.special.factorial2(int(n), exact=exact) + + # Handle array inputs + if isinstance(n, np.ndarray): + result = np.zeros_like(n, dtype=float) + for i, val in np.ndenumerate(n): + if val in {-1.0, 0.0}: + result[i] = 1.0 + elif val < -1.0: + result[i] = 0.0 + else: + result[i] = scipy.special.factorial2(int(val), exact=exact) + return result + + return None + + +# pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches def compute_overlap( obasis0: MolecularBasis, atcoords0: np.ndarray, @@ -131,7 +155,7 @@ def compute_overlap( n_max = max(np.max(shell.angmoms) for shell in obasis0.shells) if not identical: - n_max = max(n_max, *(np.max(shell.angmoms) for shell in obasis1.shells)) + n_max = max(n_max, max(np.max(shell.angmoms) for shell in obasis1.shells)) go = GaussianOverlap(n_max) # define a python ufunc (numpy function) for broadcasted calling over angular momentums @@ -140,13 +164,17 @@ def compute_overlap( # Loop over shell0 begin0 = 0 + # pylint: disable=too-many-nested-blocks for i0, shell0 in enumerate(obasis0.shells): r0 = atcoords0[shell0.icenter] end0 = begin0 + shell0.nbasis # Loop over shell1 (lower triangular only, including diagonal) begin1 = 0 - nshell1 = i0 + 1 if identical else len(obasis1.shells) + if identical: + nshell1 = i0 + 1 + else: + nshell1 = len(obasis1.shells) for i1, shell1 in enumerate(obasis1.shells[:nshell1]): r1 = atcoords1[shell1.icenter] end1 = begin1 + shell1.nbasis @@ -218,7 +246,8 @@ def compute_overlap( permutation1, signs1 = permutation0, signs0 else: permutation1, signs1 = convert_conventions(obasis1, OVERLAP_CONVENTIONS, reverse=True) - return overlap[:, permutation1] * signs1 + overlap = overlap[:, permutation1] * signs1 + return overlap class GaussianOverlap: @@ -253,7 +282,7 @@ def compute_overlap_gaussian_1d(self, x1, x2, n1, n2, two_at): return value -def _compute_cart_shell_normalizations(shell: Shell) -> np.ndarray: +def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray: """Return normalization constants for the primitives in a given shell. Parameters @@ -272,7 +301,9 @@ def _compute_cart_shell_normalizations(shell: Shell) -> np.ndarray: result = [] for angmom in shell.angmoms: for exponent in shell.exponents: - row = [gob_cart_normalization(exponent, n) for n in iter_cart_alphabet(angmom)] + row = [] + for n in iter_cart_alphabet(angmom): + row.append(gob_cart_normalization(exponent, n)) result.append(row) return np.array(result) From 58b65594086cf9c4cf59a7ea68ca980aae0a50a5 Mon Sep 17 00:00:00 2001 From: D-TheProgrammer <151149998+D-TheProgrammer@users.noreply.github.com> Date: Thu, 30 May 2024 19:14:50 +0200 Subject: [PATCH 02/12] Pytest for factorial2 --- iodata/test/test_factorial2.py | 42 ++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 iodata/test/test_factorial2.py diff --git a/iodata/test/test_factorial2.py b/iodata/test/test_factorial2.py new file mode 100644 index 00000000..e790b494 --- /dev/null +++ b/iodata/test/test_factorial2.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest + +from ..overlap import factorial2 + + +def test_integer_arguments(): + assert factorial2(0, exact=True) == 1 + assert factorial2(1, exact=True) == 1 + assert factorial2(2, exact=True) == 2 + assert factorial2(3, exact=True) == 3 + assert factorial2(-1, exact=True) == 1 + assert factorial2(-2, exact=True) == 0 + + +def test_float_arguments(): + assert factorial2(0.0, exact=False) == pytest.approx(1.0) + assert factorial2(1.0, exact=False) == pytest.approx(1.0) + assert factorial2(2.0, exact=False) == pytest.approx(2.0) + assert factorial2(3.0, exact=False) == pytest.approx(3.0) + + +def test_integer_array_argument(): + np.testing.assert_array_equal( + factorial2(np.array([0, 1, 2, 3]), exact=True), np.array([1, 1, 2, 3]) + ) + + +def test_float_array_argument(): + np.testing.assert_array_almost_equal( + factorial2(np.array([0.0, 1.0, 2.0, 3.0]), exact=False), np.array([1.0, 1.0, 2.0, 3.0]) + ) + + +def test_special_cases_exact(): + assert factorial2(-1, exact=True) == pytest.approx(1.0) + assert factorial2(-2, exact=True) == pytest.approx(0.0) + + +def test_special_cases_not_exact(): + assert factorial2(-1.0, exact=False) == pytest.approx(1.0) + assert factorial2(-2.0, exact=False) == pytest.approx(0.0) From 0d639a781cb1a94edc962de7c2fc732a001833f2 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 10:42:36 +0200 Subject: [PATCH 03/12] Fix Ruff issues --- iodata/overlap.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 351ed61c..8e1bf199 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -25,7 +25,7 @@ import scipy.special from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS -from .basis import MolecularBasis, convert_conventions, iter_cart_alphabet +from .basis import MolecularBasis, Shell, convert_conventions, iter_cart_alphabet from .overlap_cartpure import tfs __all__ = ["OVERLAP_CONVENTIONS", "compute_overlap", "gob_cart_normalization"] @@ -155,7 +155,7 @@ def compute_overlap( n_max = max(np.max(shell.angmoms) for shell in obasis0.shells) if not identical: - n_max = max(n_max, max(np.max(shell.angmoms) for shell in obasis1.shells)) + n_max = max(n_max, *(np.max(shell.angmoms) for shell in obasis1.shells)) go = GaussianOverlap(n_max) # define a python ufunc (numpy function) for broadcasted calling over angular momentums @@ -171,10 +171,7 @@ def compute_overlap( # Loop over shell1 (lower triangular only, including diagonal) begin1 = 0 - if identical: - nshell1 = i0 + 1 - else: - nshell1 = len(obasis1.shells) + nshell1 = i0 + 1 if identical else len(obasis1.shells) for i1, shell1 in enumerate(obasis1.shells[:nshell1]): r1 = atcoords1[shell1.icenter] end1 = begin1 + shell1.nbasis @@ -246,8 +243,7 @@ def compute_overlap( permutation1, signs1 = permutation0, signs0 else: permutation1, signs1 = convert_conventions(obasis1, OVERLAP_CONVENTIONS, reverse=True) - overlap = overlap[:, permutation1] * signs1 - return overlap + return overlap[:, permutation1] * signs1 class GaussianOverlap: @@ -301,9 +297,7 @@ def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray: result = [] for angmom in shell.angmoms: for exponent in shell.exponents: - row = [] - for n in iter_cart_alphabet(angmom): - row.append(gob_cart_normalization(exponent, n)) + row = [gob_cart_normalization(exponent, n) for n in iter_cart_alphabet(angmom)] result.append(row) return np.array(result) From 2f8d75241e45e7f1ab42d26c715f0134abe95c3d Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 10:44:34 +0200 Subject: [PATCH 04/12] Fix conditionals: avoid sets --- iodata/overlap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 8e1bf199..29a4b386 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -45,7 +45,7 @@ def factorial2(n, exact=False): """ # Handle integer inputs if isinstance(n, (int, np.integer)): - if n in {-1, 0}: + if n == -1: return 1.0 if n < -1: return 0.0 @@ -53,7 +53,7 @@ def factorial2(n, exact=False): # Handle float inputs if isinstance(n, float): - if n in {-1.0, 0.0}: + if n == -1.0: return 1.0 if n < -1.0: return 0.0 @@ -63,7 +63,7 @@ def factorial2(n, exact=False): if isinstance(n, np.ndarray): result = np.zeros_like(n, dtype=float) for i, val in np.ndenumerate(n): - if val in {-1.0, 0.0}: + if val == -1.0: result[i] = 1.0 elif val < -1.0: result[i] = 0.0 From 10e343cdcad52c66a27a196bd7d2fea1e1ef4445 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:11:50 +0200 Subject: [PATCH 05/12] Restrict factorial2 to integer and exact use case --- iodata/overlap.py | 55 ++++++++++++++++------------------ iodata/test/test_factorial2.py | 36 ++++++++-------------- 2 files changed, 38 insertions(+), 53 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 29a4b386..8f1de735 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -18,11 +18,12 @@ # -- """Module for computing overlap of atomic orbital basis functions.""" -from typing import Optional +from typing import Optional, Union import attr import numpy as np import scipy.special +from numpy.typing import NDArray from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS from .basis import MolecularBasis, Shell, convert_conventions, iter_cart_alphabet @@ -31,47 +32,41 @@ __all__ = ["OVERLAP_CONVENTIONS", "compute_overlap", "gob_cart_normalization"] -def factorial2(n, exact=False): - """Wrap scipy.special.factorial2 to return 1.0 when the input is -1 and handle float arrays. +def factorial2(n: Union[int, NDArray[int]]) -> Union[int, NDArray[int]]: + """Modifcied scipy.special.factorial2 to return 1 when the input is -1. This is a temporary workaround while we wait for Scipy's update. To learn more, see https://github.com/scipy/scipy/issues/18409. + This function only supports integer (array) arguments, + because this is the only relevant use case in IOData. + Parameters ---------- - n : int, float, or np.ndarray - Values to calculate n!! for. If n={0, -1}, the return value is 1. + n + Values to calculate n!! for. If n==-1, the return value is 1. For n < -1, the return value is 0. + """ # Handle integer inputs if isinstance(n, (int, np.integer)): if n == -1: - return 1.0 - if n < -1: - return 0.0 - return scipy.special.factorial2(n, exact=exact) - - # Handle float inputs - if isinstance(n, float): - if n == -1.0: - return 1.0 - if n < -1.0: - return 0.0 - return scipy.special.factorial2(int(n), exact=exact) - - # Handle array inputs + return 1 + return scipy.special.factorial2(n, exact=True) + + # Handle integer array inputs if isinstance(n, np.ndarray): - result = np.zeros_like(n, dtype=float) - for i, val in np.ndenumerate(n): - if val == -1.0: - result[i] = 1.0 - elif val < -1.0: - result[i] = 0.0 - else: - result[i] = scipy.special.factorial2(int(val), exact=exact) - return result + if issubclass(n.dtype.type, (int, np.integer)): + result = np.zeros_like(n) + for i, val in np.ndenumerate(n): + if val == -1: + result[i] = 1 + else: + result[i] = scipy.special.factorial2(val, exact=True) + return result + raise TypeError(f"Unsupported dtype of array n: {n.dtype}") - return None + raise TypeError(f"Unsupported type of argument n: {type(n)}") # pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches @@ -261,7 +256,7 @@ def __init__(self, n_max): self.binomials = [ [scipy.special.binom(n, i) for i in range(n + 1)] for n in range(n_max + 1) ] - facts = [factorial2(m, 2) for m in range(2 * n_max)] + facts = [factorial2(m) for m in range(2 * n_max)] facts.insert(0, 1) self.facts = np.array(facts) diff --git a/iodata/test/test_factorial2.py b/iodata/test/test_factorial2.py index e790b494..cf543d75 100644 --- a/iodata/test/test_factorial2.py +++ b/iodata/test/test_factorial2.py @@ -5,38 +5,28 @@ def test_integer_arguments(): - assert factorial2(0, exact=True) == 1 - assert factorial2(1, exact=True) == 1 - assert factorial2(2, exact=True) == 2 - assert factorial2(3, exact=True) == 3 - assert factorial2(-1, exact=True) == 1 - assert factorial2(-2, exact=True) == 0 + assert factorial2(0) == 1 + assert factorial2(1) == 1 + assert factorial2(2) == 2 + assert factorial2(3) == 3 + assert factorial2(-1) == 1 + assert factorial2(-2) == 0 def test_float_arguments(): - assert factorial2(0.0, exact=False) == pytest.approx(1.0) - assert factorial2(1.0, exact=False) == pytest.approx(1.0) - assert factorial2(2.0, exact=False) == pytest.approx(2.0) - assert factorial2(3.0, exact=False) == pytest.approx(3.0) + with pytest.raises(TypeError): + assert factorial2(1.0) def test_integer_array_argument(): - np.testing.assert_array_equal( - factorial2(np.array([0, 1, 2, 3]), exact=True), np.array([1, 1, 2, 3]) - ) + assert (factorial2(np.array([0, 1, 2, 3])) == np.array([1, 1, 2, 3])).all() def test_float_array_argument(): - np.testing.assert_array_almost_equal( - factorial2(np.array([0.0, 1.0, 2.0, 3.0]), exact=False), np.array([1.0, 1.0, 2.0, 3.0]) - ) + with pytest.raises(TypeError): + factorial2(np.array([0.0, 1.0, 2.0, 3.0])) def test_special_cases_exact(): - assert factorial2(-1, exact=True) == pytest.approx(1.0) - assert factorial2(-2, exact=True) == pytest.approx(0.0) - - -def test_special_cases_not_exact(): - assert factorial2(-1.0, exact=False) == pytest.approx(1.0) - assert factorial2(-2.0, exact=False) == pytest.approx(0.0) + assert factorial2(-1) == pytest.approx(1) + assert factorial2(-2) == pytest.approx(0) From 0a33bf77d1b90abde25b4f0dabe5d1204073ebe0 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:12:45 +0200 Subject: [PATCH 06/12] Move factorial2 tests to test_overlap.py --- iodata/test/test_factorial2.py | 32 -------------------------------- iodata/test/test_overlap.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 33 deletions(-) delete mode 100644 iodata/test/test_factorial2.py diff --git a/iodata/test/test_factorial2.py b/iodata/test/test_factorial2.py deleted file mode 100644 index cf543d75..00000000 --- a/iodata/test/test_factorial2.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np -import pytest - -from ..overlap import factorial2 - - -def test_integer_arguments(): - assert factorial2(0) == 1 - assert factorial2(1) == 1 - assert factorial2(2) == 2 - assert factorial2(3) == 3 - assert factorial2(-1) == 1 - assert factorial2(-2) == 0 - - -def test_float_arguments(): - with pytest.raises(TypeError): - assert factorial2(1.0) - - -def test_integer_array_argument(): - assert (factorial2(np.array([0, 1, 2, 3])) == np.array([1, 1, 2, 3])).all() - - -def test_float_array_argument(): - with pytest.raises(TypeError): - factorial2(np.array([0.0, 1.0, 2.0, 3.0])) - - -def test_special_cases_exact(): - assert factorial2(-1) == pytest.approx(1) - assert factorial2(-2) == pytest.approx(0) diff --git a/iodata/test/test_overlap.py b/iodata/test/test_overlap.py index 64a76fe6..d775dd2e 100644 --- a/iodata/test/test_overlap.py +++ b/iodata/test/test_overlap.py @@ -27,7 +27,7 @@ from ..api import load_one from ..basis import MolecularBasis, Shell, convert_conventions -from ..overlap import OVERLAP_CONVENTIONS, compute_overlap +from ..overlap import OVERLAP_CONVENTIONS, compute_overlap, factorial2 try: from importlib_resources import as_file, files @@ -35,6 +35,34 @@ from importlib.resources import as_file, files +def test_integer_arguments(): + assert factorial2(0) == 1 + assert factorial2(1) == 1 + assert factorial2(2) == 2 + assert factorial2(3) == 3 + assert factorial2(-1) == 1 + assert factorial2(-2) == 0 + + +def test_float_arguments(): + with pytest.raises(TypeError): + assert factorial2(1.0) + + +def test_integer_array_argument(): + assert (factorial2(np.array([0, 1, 2, 3])) == np.array([1, 1, 2, 3])).all() + + +def test_float_array_argument(): + with pytest.raises(TypeError): + factorial2(np.array([0.0, 1.0, 2.0, 3.0])) + + +def test_special_cases_exact(): + assert factorial2(-1) == pytest.approx(1) + assert factorial2(-2) == pytest.approx(0) + + def test_normalization_basics_segmented(): for angmom in range(7): shells = [Shell(0, [angmom], ["c"], np.array([0.23]), np.array([[1.0]]))] From 128edbba791152e6d145d2988db2008e44116658 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:14:20 +0200 Subject: [PATCH 07/12] Fix type test factorial2 --- iodata/test/test_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iodata/test/test_overlap.py b/iodata/test/test_overlap.py index d775dd2e..2f68a0dc 100644 --- a/iodata/test/test_overlap.py +++ b/iodata/test/test_overlap.py @@ -46,7 +46,7 @@ def test_integer_arguments(): def test_float_arguments(): with pytest.raises(TypeError): - assert factorial2(1.0) + factorial2(1.0) def test_integer_array_argument(): From 99d332e710bdf7ed217f3dbede02933aff5ade47 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:19:35 +0200 Subject: [PATCH 08/12] Remove pylint comments --- iodata/overlap.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 8f1de735..6e143de5 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -69,7 +69,6 @@ def factorial2(n: Union[int, NDArray[int]]) -> Union[int, NDArray[int]]: raise TypeError(f"Unsupported type of argument n: {type(n)}") -# pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches def compute_overlap( obasis0: MolecularBasis, atcoords0: np.ndarray, @@ -159,7 +158,6 @@ def compute_overlap( # Loop over shell0 begin0 = 0 - # pylint: disable=too-many-nested-blocks for i0, shell0 in enumerate(obasis0.shells): r0 = atcoords0[shell0.icenter] end0 = begin0 + shell0.nbasis From b657183dc3c872d904d79a3ae3ace93de1570e2e Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:20:54 +0200 Subject: [PATCH 09/12] Improve docstring --- iodata/overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 6e143de5..0ccb1350 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -33,7 +33,7 @@ def factorial2(n: Union[int, NDArray[int]]) -> Union[int, NDArray[int]]: - """Modifcied scipy.special.factorial2 to return 1 when the input is -1. + """Modifcied scipy.special.factorial2 that returns 1 when the input is -1. This is a temporary workaround while we wait for Scipy's update. To learn more, see https://github.com/scipy/scipy/issues/18409. From 5b8ec3b95eceee08e2639ef33d3b76844936f010 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:21:21 +0200 Subject: [PATCH 10/12] Remove redundant quotes --- iodata/overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index 0ccb1350..c9fc68e8 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -271,7 +271,7 @@ def compute_overlap_gaussian_1d(self, x1, x2, n1, n2, two_at): return value -def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray: +def _compute_cart_shell_normalizations(shell: Shell) -> np.ndarray: """Return normalization constants for the primitives in a given shell. Parameters From f539ff7a89ab124941430dacd6a92479a93c91eb Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 11:32:24 +0200 Subject: [PATCH 11/12] Reduce complexity and vectorize --- iodata/overlap.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/iodata/overlap.py b/iodata/overlap.py index c9fc68e8..85a9d8af 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -50,19 +50,13 @@ def factorial2(n: Union[int, NDArray[int]]) -> Union[int, NDArray[int]]: """ # Handle integer inputs if isinstance(n, (int, np.integer)): - if n == -1: - return 1 - return scipy.special.factorial2(n, exact=True) + return 1 if n == -1 else scipy.special.factorial2(n, exact=True) # Handle integer array inputs if isinstance(n, np.ndarray): if issubclass(n.dtype.type, (int, np.integer)): - result = np.zeros_like(n) - for i, val in np.ndenumerate(n): - if val == -1: - result[i] = 1 - else: - result[i] = scipy.special.factorial2(val, exact=True) + result = scipy.special.factorial2(n, exact=True) + result[n == -1] = 1 return result raise TypeError(f"Unsupported dtype of array n: {n.dtype}") From d40b54380588757ee97a89dba4456677fd75954f Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 3 Jun 2024 19:31:23 +0200 Subject: [PATCH 12/12] Factorial2 test improvements - Include type checks - Include other arrays - Include more cases --- iodata/test/test_overlap.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/iodata/test/test_overlap.py b/iodata/test/test_overlap.py index 2f68a0dc..9a2e5bc4 100644 --- a/iodata/test/test_overlap.py +++ b/iodata/test/test_overlap.py @@ -35,34 +35,30 @@ from importlib.resources import as_file, files -def test_integer_arguments(): - assert factorial2(0) == 1 - assert factorial2(1) == 1 - assert factorial2(2) == 2 - assert factorial2(3) == 3 - assert factorial2(-1) == 1 - assert factorial2(-2) == 0 +@pytest.mark.parametrize( + ("inp", "out"), [(0, 1), (1, 1), (2, 2), (3, 3), (4, 8), (5, 15), (-1, 1), (-2, 0)] +) +def test_factorial2_integer_arguments(inp, out): + assert factorial2(inp) == out + assert isinstance(factorial2(inp), int) -def test_float_arguments(): +def test_factorial2_float_arguments(): with pytest.raises(TypeError): factorial2(1.0) -def test_integer_array_argument(): - assert (factorial2(np.array([0, 1, 2, 3])) == np.array([1, 1, 2, 3])).all() +def test_factorial2_integer_array_argument(): + assert (factorial2(np.array([-2, -1, 4, 5])) == np.array([0, 1, 8, 15])).all() + assert (factorial2(np.array([[-2, -1], [4, 5]])) == np.array([[0, 1], [8, 15]])).all() + assert issubclass(factorial2(np.array([-2, -1, 4, 5])).dtype.type, np.integer) -def test_float_array_argument(): +def test_factorial2_float_array_argument(): with pytest.raises(TypeError): factorial2(np.array([0.0, 1.0, 2.0, 3.0])) -def test_special_cases_exact(): - assert factorial2(-1) == pytest.approx(1) - assert factorial2(-2) == pytest.approx(0) - - def test_normalization_basics_segmented(): for angmom in range(7): shells = [Shell(0, [angmom], ["c"], np.array([0.23]), np.array([[1.0]]))]