Skip to content

Commit

Permalink
Migrate from attr to attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
tovrstra committed May 29, 2024
1 parent 486e413 commit 99e0f94
Show file tree
Hide file tree
Showing 25 changed files with 193 additions and 180 deletions.
31 changes: 16 additions & 15 deletions iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from numbers import Integral
from typing import Union

import attr
import attrs
import numpy as np
from numpy.typing import NDArray

from .attrutils import validate_shape

Expand Down Expand Up @@ -100,7 +101,7 @@ def angmom_its(angmom: Union[int, list[int]]) -> Union[str, list[str]]:
return ANGMOM_CHARS[angmom]


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class Shell:
"""A shell in a molecular basis representing (generalized) contractions with the same exponents.
Expand All @@ -126,11 +127,11 @@ class Shell:
"""

icenter: int
angmoms: list[int] = attr.ib(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attr.ib(validator=validate_shape(("coeffs", 1)))
exponents: np.ndarray = attr.ib(validator=validate_shape(("coeffs", 0)))
coeffs: np.ndarray = attr.ib(validator=validate_shape(("exponents", 0), ("kinds", 0)))
icenter: int = attrs.field()
angmoms: list[int] = attrs.field(validator=validate_shape(("coeffs", 1)))
kinds: list[str] = attrs.field(validator=validate_shape(("coeffs", 1)))
exponents: NDArray = attrs.field(validator=validate_shape(("coeffs", 0)))
coeffs: NDArray = attrs.field(validator=validate_shape(("exponents", 0), ("kinds", 0)))

@property
def nbasis(self) -> int:
Expand All @@ -156,7 +157,7 @@ def ncon(self) -> int:
return len(self.angmoms)


@attr.s(auto_attribs=True, slots=True, on_setattr=[attr.setters.validate, attr.setters.convert])
@attrs.define
class MolecularBasis:
"""A complete molecular orbital or density basis set.
Expand Down Expand Up @@ -205,9 +206,9 @@ class MolecularBasis:
"""

shells: list[Shell]
conventions: dict[str, str]
primitive_normalization: str
shells: list[Shell] = attrs.field()
conventions: dict[str, str] = attrs.field()
primitive_normalization: str = attrs.field()

@property
def nbasis(self) -> int:
Expand All @@ -222,12 +223,12 @@ def get_segmented(self):
shells.append(
Shell(shell.icenter, [angmom], [kind], shell.exponents, coeffs.reshape(-1, 1))
)
return attr.evolve(self, shells=shells)
return attrs.evolve(self, shells=shells)


def convert_convention_shell(
conv1: list[str], conv2: list[str], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.
The transformation from convention 1 to convention 2 can be done applying
Expand Down Expand Up @@ -289,7 +290,7 @@ def convert_convention_shell(

def convert_conventions(
molbasis: MolecularBasis, new_conventions: dict[str, list[str]], reverse=False
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray, NDArray]:
"""Return a permutation vector and sign changes to convert from 1 to 2.
The transformation from molbasis.convention to the new convention can be done
Expand Down Expand Up @@ -339,7 +340,7 @@ def convert_conventions(
return np.array(permutation), np.array(signs)


def iter_cart_alphabet(n: int) -> np.ndarray:
def iter_cart_alphabet(n: int) -> NDArray:
"""Loop over powers of Cartesian basis functions in alphabetical order.
See https://theochem.github.io/horton/2.1.1/tech_ref_gaussian_basis.html
Expand Down
3 changes: 2 additions & 1 deletion iodata/formats/chgcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..periodic import sym2num
Expand All @@ -37,7 +38,7 @@
PATTERNS = ["CHGCAR*", "AECCAR*"]


def _load_vasp_header(lit: LineIterator) -> tuple[str, np.ndarray, np.ndarray, np.ndarray]:
def _load_vasp_header(lit: LineIterator) -> tuple[str, NDArray, NDArray, NDArray]:
"""Load the cell and atoms from a VASP file format.
Parameters
Expand Down
15 changes: 7 additions & 8 deletions iodata/formats/cp2klog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorialk

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, angmom_sti
Expand All @@ -42,9 +43,7 @@
}


def _get_cp2k_norm_corrections(
ell: int, alphas: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
def _get_cp2k_norm_corrections(ell: int, alphas: Union[float, NDArray]) -> Union[float, NDArray]:
"""Compute the corrections for the normalization of the basis functions.
This correction is needed because the CP2K atom code works with a different
Expand Down Expand Up @@ -236,7 +235,7 @@ def _read_cp2k_occupations_energies(

def _read_cp2k_orbital_coeffs(
lit: LineIterator, oe: list[tuple[int, int, float, float]]
) -> dict[tuple[int, int], np.ndarray]:
) -> dict[tuple[int, int], NDArray]:
"""Read the expansion coefficients of the orbital from an open CP2K ATOM output.
Parameters
Expand Down Expand Up @@ -294,11 +293,11 @@ def _get_norb_nel(oe: list[tuple[int, int, float, float]]) -> tuple[int, float]:


def _fill_orbitals(
orb_coeffs: np.ndarray,
orb_energies: np.ndarray,
orb_occupations: np.ndarray,
orb_coeffs: NDArray,
orb_energies: NDArray,
orb_occupations: NDArray,
oe: list[tuple[int, int, float, float]],
coeffs: dict[tuple[int, int], np.ndarray],
coeffs: dict[tuple[int, int], NDArray],
obasis: MolecularBasis,
restricted: bool,
):
Expand Down
19 changes: 10 additions & 9 deletions iodata/formats/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
Expand All @@ -42,7 +43,7 @@

def _read_cube_header(
lit: LineIterator,
) -> tuple[str, np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray], np.ndarray]:
) -> tuple[str, NDArray, NDArray, NDArray, dict[str, NDArray], NDArray]:
"""Load header data from a CUBE file object.
Parameters
Expand All @@ -62,7 +63,7 @@ def _read_cube_header(
# skip the second line
next(lit)

def read_grid_line(line: str) -> tuple[int, np.ndarray]:
def read_grid_line(line: str) -> tuple[int, NDArray]:
"""Read a grid line from the cube file."""
words = line.split()
return (
Expand All @@ -83,7 +84,7 @@ def read_grid_line(line: str) -> tuple[int, np.ndarray]:
cellvecs = axes * shape.reshape(-1, 1)
cube = {"origin": origin, "axes": axes, "shape": shape}

def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
def read_atom_line(line: str) -> tuple[int, float, NDArray]:
"""Read an atomic number and coordinate from the cube file."""
words = line.split()
return (
Expand All @@ -106,7 +107,7 @@ def read_atom_line(line: str) -> tuple[int, float, np.ndarray]:
return title, atcoords, atnums, cellvecs, cube, atcorenums


def _read_cube_data(lit: LineIterator, cube: dict[str, np.ndarray]):
def _read_cube_data(lit: LineIterator, cube: dict[str, NDArray]):
"""Load cube data from a CUBE file object.
Parameters
Expand Down Expand Up @@ -150,10 +151,10 @@ def load_one(lit: LineIterator) -> dict:
def _write_cube_header(
f: TextIO,
title: str,
atcoords: np.ndarray,
atnums: np.ndarray,
cube: dict[str, np.ndarray],
atcorenums: np.ndarray,
atcoords: NDArray,
atnums: NDArray,
cube: dict[str, NDArray],
atcorenums: NDArray,
):
print(title, file=f)
print("OUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z", file=f)
Expand All @@ -169,7 +170,7 @@ def _write_cube_header(
print(f"{atnums[i]:5d} {q: 11.6f} {x: 11.6f} {y: 11.6f} {z: 11.6f}", file=f)


def _write_cube_data(f: TextIO, cube_data: np.ndarray, block_size: int):
def _write_cube_data(f: TextIO, cube_data: NDArray, block_size: int):
counter = 0
for value in cube_data.flat:
f.write(f" {value: 12.5E}")
Expand Down
7 changes: 4 additions & 3 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Optional, TextIO

import numpy as np
from numpy.typing import NDArray

from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell, convert_conventions
from ..docstrings import document_dump_one, document_load_many, document_load_one
Expand Down Expand Up @@ -473,7 +474,7 @@ def _load_dm(label: str, fchk: dict, result: dict, key: str):
result[key] = _triangle_to_dense(fchk[label])


def _triangle_to_dense(triangle: np.ndarray) -> np.ndarray:
def _triangle_to_dense(triangle: NDArray) -> NDArray:
"""Convert a symmetric matrix in triangular storage to a dense square matrix.
Parameters
Expand Down Expand Up @@ -512,7 +513,7 @@ def _dump_real_scalars(name: str, val: float, f: TextIO):
print(f"{name:40} R {float(val): 16.8E}", file=f)


def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_integer_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of integers."""
nval = val.size
if nval != 0:
Expand All @@ -527,7 +528,7 @@ def _dump_integer_arrays(name: str, val: np.ndarray, f: TextIO):
k = 0


def _dump_real_arrays(name: str, val: np.ndarray, f: TextIO):
def _dump_real_arrays(name: str, val: NDArray, f: TextIO):
"""Dumper for a array of float."""
nval = val.size
if nval != 0:
Expand Down
5 changes: 3 additions & 2 deletions iodata/formats/gamess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""GAMESS punch file format."""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, angstrom
Expand Down Expand Up @@ -81,7 +82,7 @@ def _read_energy(lit: LineIterator, result: dict) -> tuple:
return energy, gradient


def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
def _read_hessian(lit: LineIterator, result: dict) -> NDArray:
"""Extract ``hessian`` from the punch file."""
# check that $HESS is not already parsed
if "athessian" in result:
Expand All @@ -102,7 +103,7 @@ def _read_hessian(lit: LineIterator, result: dict) -> np.ndarray:
return hessian


def _read_masses(lit: LineIterator, result: dict) -> np.ndarray:
def _read_masses(lit: LineIterator, result: dict) -> NDArray:
"""Extract ``masses`` from the punch file."""
natom = len(result["symbols"])
masses = np.zeros(natom, float)
Expand Down
5 changes: 3 additions & 2 deletions iodata/formats/gaussianlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"""

import numpy as np
from numpy.typing import NDArray

from ..docstrings import document_load_one
from ..utils import LineIterator, set_four_index_element
Expand Down Expand Up @@ -73,7 +74,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a two-index operator from a GAUSSIAN LOG file format.
Parameters
Expand Down Expand Up @@ -106,7 +107,7 @@ def _load_twoindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
return result


def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> np.ndarray:
def _load_fourindex_g09(lit: LineIterator, nbasis: int) -> NDArray:
"""Load a four-index operator from a GAUSSIAN LOG file.
Parameters
Expand Down
7 changes: 3 additions & 4 deletions iodata/formats/mol2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TextIO

import numpy as np
from numpy.typing import NDArray

from ..docstrings import (
document_dump_many,
Expand Down Expand Up @@ -83,9 +84,7 @@ def load_one(lit: LineIterator) -> dict:
return result


def _load_helper_atoms(
lit: LineIterator, natoms: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, tuple]:
def _load_helper_atoms(lit: LineIterator, natoms: int) -> tuple[NDArray, NDArray, NDArray, tuple]:
"""Load element numbers, coordinates and atomic charges."""
atnums = np.empty(natoms)
atcoords = np.empty((natoms, 3))
Expand All @@ -112,7 +111,7 @@ def _load_helper_atoms(
return atnums, atcoords, atchgs, attypes


def _load_helper_bonds(lit: LineIterator, nbonds: int) -> tuple[np.ndarray]:
def _load_helper_bonds(lit: LineIterator, nbonds: int) -> NDArray:
"""Load bond information.
Each line in a bond definition has the following structure
Expand Down
15 changes: 7 additions & 8 deletions iodata/formats/molden.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
import copy
from typing import TextIO, Union

import attr
import attrs
import numpy as np
from numpy.typing import NDArray

from ..basis import (
HORTON2_CONVENTIONS,
Expand Down Expand Up @@ -225,9 +226,7 @@ def _load_low(lit: LineIterator) -> dict:
return result


def _load_helper_atoms(
lit: LineIterator, cunit: float
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _load_helper_atoms(lit: LineIterator, cunit: float) -> tuple[NDArray, NDArray, NDArray]:
"""Load element numbers and coordinates."""
atnums = []
atcorenums = []
Expand Down Expand Up @@ -357,9 +356,9 @@ def _load_helper_coeffs(lit: LineIterator) -> tuple:

def _is_normalized_properly(
obasis: MolecularBasis,
atcoords: np.ndarray,
orb_alpha: np.ndarray,
orb_beta: np.ndarray,
atcoords: NDArray,
orb_alpha: NDArray,
orb_beta: NDArray,
norm_threshold: float = 1e-4,
) -> bool:
"""Test the normalization of the occupied and virtual orbitals.
Expand Down Expand Up @@ -551,7 +550,7 @@ def _fix_obasis_normalize_contractions(obasis: MolecularBasis) -> MolecularBasis
fixed_shells = []
for shell in obasis.shells:
shell_obasis = MolecularBasis(
[attr.evolve(shell, icenter=0)], obasis.conventions, obasis.primitive_normalization
[attrs.evolve(shell, icenter=0)], obasis.conventions, obasis.primitive_normalization
)
# 2) Get the first diagonal element of the overlap matrix
olpdiag = compute_overlap(shell_obasis, np.zeros((1, 3), float))[0, 0]
Expand Down
Loading

0 comments on commit 99e0f94

Please sign in to comment.