Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix factorial2 and pytest fix #308 and #309 #319

Merged
merged 12 commits into from
Jun 3, 2024
36 changes: 24 additions & 12 deletions iodata/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
# --
"""Module for computing overlap of atomic orbital basis functions."""

from typing import Optional
from typing import Optional, Union

import attr
import numpy as np
import scipy.special
from numpy.typing import NDArray

from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS
from .basis import MolecularBasis, Shell, convert_conventions, iter_cart_alphabet
Expand All @@ -31,24 +32,35 @@
__all__ = ["OVERLAP_CONVENTIONS", "compute_overlap", "gob_cart_normalization"]


def factorial2(n, exact=False):
"""Wrap scipy.special.factorial2 to return 1.0 when the input is -1.
def factorial2(n: Union[int, NDArray[int]]) -> Union[int, NDArray[int]]:
"""Modifcied scipy.special.factorial2 that returns 1 when the input is -1.

This is a temporary workaround while we wait for Scipy's update.
To learn more, see https://github.com/scipy/scipy/issues/18409.

This function only supports integer (array) arguments,
because this is the only relevant use case in IOData.

Parameters
----------
n : int or np.ndarray
Values to calculate n!! for. If n={0, -1}, the return value is 1.
n
Values to calculate n!! for. If n==-1, the return value is 1.
For n < -1, the return value is 0.

"""
# Scipy 1.11.x returns an integer when n is an integer, but 1.10.x returns an array,
# so np.array(n) is passed to make sure the output is always an array.
out = scipy.special.factorial2(np.array(n), exact=exact)
out[out <= 0] = 1.0
out[out <= -2] = 0.0
return out
# Handle integer inputs
if isinstance(n, (int, np.integer)):
return 1 if n == -1 else scipy.special.factorial2(n, exact=True)

# Handle integer array inputs
if isinstance(n, np.ndarray):
if issubclass(n.dtype.type, (int, np.integer)):
result = scipy.special.factorial2(n, exact=True)
result[n == -1] = 1
return result
raise TypeError(f"Unsupported dtype of array n: {n.dtype}")

raise TypeError(f"Unsupported type of argument n: {type(n)}")


def compute_overlap(
Expand Down Expand Up @@ -236,7 +248,7 @@ def __init__(self, n_max):
self.binomials = [
[scipy.special.binom(n, i) for i in range(n + 1)] for n in range(n_max + 1)
]
facts = [factorial2(m, 2) for m in range(2 * n_max)]
facts = [factorial2(m) for m in range(2 * n_max)]
facts.insert(0, 1)
self.facts = np.array(facts)

Expand Down
26 changes: 25 additions & 1 deletion iodata/test/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,38 @@

from ..api import load_one
from ..basis import MolecularBasis, Shell, convert_conventions
from ..overlap import OVERLAP_CONVENTIONS, compute_overlap
from ..overlap import OVERLAP_CONVENTIONS, compute_overlap, factorial2

try:
from importlib_resources import as_file, files
except ImportError:
from importlib.resources import as_file, files


@pytest.mark.parametrize(
("inp", "out"), [(0, 1), (1, 1), (2, 2), (3, 3), (4, 8), (5, 15), (-1, 1), (-2, 0)]
)
def test_factorial2_integer_arguments(inp, out):
assert factorial2(inp) == out
assert isinstance(factorial2(inp), int)


def test_factorial2_float_arguments():
with pytest.raises(TypeError):
factorial2(1.0)


def test_factorial2_integer_array_argument():
assert (factorial2(np.array([-2, -1, 4, 5])) == np.array([0, 1, 8, 15])).all()
assert (factorial2(np.array([[-2, -1], [4, 5]])) == np.array([[0, 1], [8, 15]])).all()
assert issubclass(factorial2(np.array([-2, -1, 4, 5])).dtype.type, np.integer)


def test_factorial2_float_array_argument():
with pytest.raises(TypeError):
factorial2(np.array([0.0, 1.0, 2.0, 3.0]))


def test_normalization_basics_segmented():
for angmom in range(7):
shells = [Shell(0, [angmom], ["c"], np.array([0.23]), np.array([[1.0]]))]
Expand Down