Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix factorial2 and pytest fix #308 and #309 #319

Merged
merged 12 commits into from
Jun 3, 2024
63 changes: 47 additions & 16 deletions iodata/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,56 @@
import scipy.special

from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS
from .basis import MolecularBasis, Shell, convert_conventions, iter_cart_alphabet
from .basis import MolecularBasis, convert_conventions, iter_cart_alphabet
from .overlap_cartpure import tfs

__all__ = ["OVERLAP_CONVENTIONS", "compute_overlap", "gob_cart_normalization"]


def factorial2(n, exact=False):
"""Wrap scipy.special.factorial2 to return 1.0 when the input is -1.
"""Wrap scipy.special.factorial2 to return 1.0 when the input is -1 and handle float arrays.

This is a temporary workaround while we wait for Scipy's update.
To learn more, see https://github.com/scipy/scipy/issues/18409.

Parameters
----------
n : int or np.ndarray
n : int, float, or np.ndarray
Values to calculate n!! for. If n={0, -1}, the return value is 1.
For n < -1, the return value is 0.
"""
# Scipy 1.11.x returns an integer when n is an integer, but 1.10.x returns an array,
# so np.array(n) is passed to make sure the output is always an array.
out = scipy.special.factorial2(np.array(n), exact=exact)
out[out <= 0] = 1.0
out[out <= -2] = 0.0
return out


# Handle integer inputs
if isinstance(n, (int, np.integer)):
if n in {-1, 0}:
return 1.0
if n < -1:
return 0.0
return scipy.special.factorial2(n, exact=exact)

# Handle float inputs
if isinstance(n, float):
if n in {-1.0, 0.0}:
return 1.0
if n < -1.0:
return 0.0
return scipy.special.factorial2(int(n), exact=exact)

# Handle array inputs
if isinstance(n, np.ndarray):
result = np.zeros_like(n, dtype=float)
for i, val in np.ndenumerate(n):
if val in {-1.0, 0.0}:
result[i] = 1.0
elif val < -1.0:
result[i] = 0.0
else:
result[i] = scipy.special.factorial2(int(val), exact=exact)
return result

return None


# pylint: disable=too-many-nested-blocks,too-many-statements,too-many-branches
def compute_overlap(
obasis0: MolecularBasis,
atcoords0: np.ndarray,
Expand Down Expand Up @@ -131,7 +155,7 @@ def compute_overlap(

n_max = max(np.max(shell.angmoms) for shell in obasis0.shells)
if not identical:
n_max = max(n_max, *(np.max(shell.angmoms) for shell in obasis1.shells))
n_max = max(n_max, max(np.max(shell.angmoms) for shell in obasis1.shells))
go = GaussianOverlap(n_max)

# define a python ufunc (numpy function) for broadcasted calling over angular momentums
Expand All @@ -140,13 +164,17 @@ def compute_overlap(
# Loop over shell0
begin0 = 0

# pylint: disable=too-many-nested-blocks
for i0, shell0 in enumerate(obasis0.shells):
r0 = atcoords0[shell0.icenter]
end0 = begin0 + shell0.nbasis

# Loop over shell1 (lower triangular only, including diagonal)
begin1 = 0
nshell1 = i0 + 1 if identical else len(obasis1.shells)
if identical:
nshell1 = i0 + 1
else:
nshell1 = len(obasis1.shells)
for i1, shell1 in enumerate(obasis1.shells[:nshell1]):
r1 = atcoords1[shell1.icenter]
end1 = begin1 + shell1.nbasis
Expand Down Expand Up @@ -218,7 +246,8 @@ def compute_overlap(
permutation1, signs1 = permutation0, signs0
else:
permutation1, signs1 = convert_conventions(obasis1, OVERLAP_CONVENTIONS, reverse=True)
return overlap[:, permutation1] * signs1
overlap = overlap[:, permutation1] * signs1
return overlap


class GaussianOverlap:
Expand Down Expand Up @@ -253,7 +282,7 @@ def compute_overlap_gaussian_1d(self, x1, x2, n1, n2, two_at):
return value


def _compute_cart_shell_normalizations(shell: Shell) -> np.ndarray:
def _compute_cart_shell_normalizations(shell: "Shell") -> np.ndarray:
"""Return normalization constants for the primitives in a given shell.

Parameters
Expand All @@ -272,7 +301,9 @@ def _compute_cart_shell_normalizations(shell: Shell) -> np.ndarray:
result = []
for angmom in shell.angmoms:
for exponent in shell.exponents:
row = [gob_cart_normalization(exponent, n) for n in iter_cart_alphabet(angmom)]
row = []
for n in iter_cart_alphabet(angmom):
row.append(gob_cart_normalization(exponent, n))
result.append(row)
return np.array(result)

Expand Down
42 changes: 42 additions & 0 deletions iodata/test/test_factorial2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import pytest

from ..overlap import factorial2


def test_integer_arguments():
assert factorial2(0, exact=True) == 1
assert factorial2(1, exact=True) == 1
assert factorial2(2, exact=True) == 2
assert factorial2(3, exact=True) == 3
assert factorial2(-1, exact=True) == 1
assert factorial2(-2, exact=True) == 0


def test_float_arguments():
assert factorial2(0.0, exact=False) == pytest.approx(1.0)
assert factorial2(1.0, exact=False) == pytest.approx(1.0)
assert factorial2(2.0, exact=False) == pytest.approx(2.0)
assert factorial2(3.0, exact=False) == pytest.approx(3.0)


def test_integer_array_argument():
np.testing.assert_array_equal(
factorial2(np.array([0, 1, 2, 3]), exact=True), np.array([1, 1, 2, 3])
)


def test_float_array_argument():
np.testing.assert_array_almost_equal(
factorial2(np.array([0.0, 1.0, 2.0, 3.0]), exact=False), np.array([1.0, 1.0, 2.0, 3.0])
)


def test_special_cases_exact():
assert factorial2(-1, exact=True) == pytest.approx(1.0)
assert factorial2(-2, exact=True) == pytest.approx(0.0)


def test_special_cases_not_exact():
assert factorial2(-1.0, exact=False) == pytest.approx(1.0)
assert factorial2(-2.0, exact=False) == pytest.approx(0.0)