Skip to content

Commit

Permalink
Add extended typing through janus_types module
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Feb 29, 2024
1 parent a100b69 commit aa16d61
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 41 deletions.
25 changes: 13 additions & 12 deletions janus_core/geom_opt.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
"""Geometry optimization."""

from typing import Any, Optional
from typing import Any, Optional, Callable

from ase import Atoms
from ase.io import read, write
from ase.optimize import LBFGS

try:
from ase.filters import FrechetCellFilter as DefaultFilter
except ImportError:
from ase.constraints import ExpCellFilter as DefaultFilter

from ase.optimize import LBFGS
from .janus_types import ASEOptArgs, ASEOptRunArgs, ASEWriteArgs


def optimize(
atoms: Atoms,
fmax: float = 0.1,
dyn_kwargs: Optional[dict[str, Any]] = None,
filter_func: Optional[callable] = DefaultFilter,
dyn_kwargs: Optional[ASEOptRunArgs] = None,
filter_func: Optional[Callable] = DefaultFilter,
filter_kwargs: Optional[dict[str, Any]] = None,
optimizer: callable = LBFGS,
opt_kwargs: Optional[dict[str, Any]] = None,
struct_kwargs: Optional[dict[str, Any]] = None,
traj_kwargs: Optional[dict[str, Any]] = None,
optimizer: Callable = LBFGS,
opt_kwargs: Optional[ASEOptArgs] = None,
struct_kwargs: Optional[ASEWriteArgs] = None,
traj_kwargs: Optional[ASEWriteArgs] = None,
) -> Atoms:
"""Optimize geometry of input structure.
Expand All @@ -32,7 +33,7 @@ def optimize(
Atoms object to optimize geometry for.
fmax : float
Set force convergence criteria for optimizer in units eV/Å.
dyn_kwargs : Optional[dict[str, Any]]
dyn_kwargs : Optional[ASEOptRunArgs]
kwargs to pass to dyn.run. Default is {}.
filter_func : Optional[callable]
Apply constraints to atoms through ASE filter function.
Expand All @@ -41,12 +42,12 @@ def optimize(
kwargs to pass to filter_func. Default is {}.
optimzer : callable
ASE optimization function. Default is `LBFGS`.
opt_kwargs : Optional[dict[str, Any]]
opt_kwargs : Optional[ASEOptArgs]
kwargs to pass to optimzer. Default is {}.
struct_kwargs : Optional[dict[str, Any]]
struct_kwargs : Optional[ASEWriteArgs]
kwargs to pass to ase.io.write to save optimized structure.
Must include "filename" keyword. Default is {}.
traj_kwargs : Optional[dict[str, Any]]
traj_kwargs : Optional[ASEWriteArgs]
kwargs to pass to ase.io.write to save optimization trajectory.
Must include "filename" keyword. Default is {}.
Expand Down
62 changes: 62 additions & 0 deletions janus_core/janus_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Module containing types used in Janus-Core
"""
from pathlib import Path, PurePath
from typing import IO, List, Literal, Optional, Sequence, TypedDict, TypeVar, Union

from ase import Atoms
import numpy as np
from numpy.typing import NDArray


# General

T = TypeVar("T")
MaybeList = Union[T, List[T]]
MaybeSequence = Union[T, Sequence[T]]
PathLike = Union[str, Path]


# ASE Arg types

class ASEReadArgs(TypedDict, total=False):
"""Main arguments for ase.io.read"""
filename: Union[str, PurePath, IO]
index: Union[int, slice, str]
format: Optional[str]
parallel: bool
do_not_split_by_at_sign: bool


class ASEWriteArgs(TypedDict, total=False):
"""Main arguments for ase.io.write"""
filename: Union[str, PurePath, IO]
images: MaybeSequence[Atoms]
format: Optional[str]
parallel: bool
append: bool


class ASEOptArgs(TypedDict, total=False):
"""Main arugments for ase optimisers"""
restart: Optional[bool]
logfile: Optional[PathLike]
trajectory: PathLike


class ASEOptRunArgs(TypedDict, total=False):
"""Main arugments for running ase optimisers"""
fmax: float
steps: int


# Janus specific
Architectures = Literal["mace", "mace_mp", "mace_off", "m3gnet", "chgnet"]
Devices = Literal["cpu", "cuda", "mps"]


class CalcResults(TypedDict, total=False):
"""Return type from calculations"""
energy: MaybeList[float]
forces: MaybeList[NDArray[np.float64]]
stress: MaybeList[NDArray[np.float64]]
16 changes: 8 additions & 8 deletions janus_core/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,24 @@
- https://github.com/Quantum-Accelerators/quacc.git
"""

from typing import Literal

from ase.calculators.calculator import Calculator

architectures = ["mace", "mace_mp", "mace_off", "m3gnet", "chgnet"]
devices = ["cpu", "cuda", "mps"]
from .janus_types import Architectures, Devices


def choose_calculator(
architecture: Literal[architectures] = "mace",
device: Literal[devices] = "cuda",
architecture: Architectures = "mace",
device: Devices = "cuda",
**kwargs,
) -> Calculator:
"""Choose MLIP calculator to configure.
Parameters
----------
architecture : Literal[architectures], optional
architecture : Architectures, optional
MLIP architecture. Default is "mace".
device: Devices, optional
Device to run on. Default is "cuda"
Raises
------
Expand Down Expand Up @@ -79,7 +78,8 @@ def choose_calculator(

else:
raise ValueError(
f"Unrecognized {architecture=}. Suported architectures are {architectures}"
f"Unrecognized {architecture=}. Suported architectures "
f"are {', '.join(Architectures.__args__)}"
)

calculator.parameters["version"] = __version__
Expand Down
42 changes: 21 additions & 21 deletions janus_core/single_point.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Perpare and perform single point calculations."""

import pathlib
from typing import Any, Literal, Optional, Union
from typing import Optional

from ase.io import read
from numpy import ndarray

from janus_core.mlip_calculators import architectures, choose_calculator, devices
from janus_core.mlip_calculators import choose_calculator
from .janus_types import Architectures, ASEReadArgs, CalcResults, Devices, MaybeList, MaybeSequence


class SinglePoint:
Expand All @@ -15,9 +16,9 @@ class SinglePoint:
def __init__(
self,
system: str,
architecture: Literal[architectures] = "mace_mp",
device: Literal[devices] = "cpu",
read_kwargs: Optional[dict[str, Any]] = None,
architecture: Architectures = "mace_mp",
device: Devices = "cpu",
read_kwargs: Optional[ASEReadArgs] = None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -54,13 +55,13 @@ def read_system(self, **kwargs) -> None:
self.sysname = pathlib.Path(self.system).stem

def set_calculator(
self, read_kwargs: Optional[dict[str, Any]] = None, **kwargs
self, read_kwargs: Optional[ASEReadArgs] = None, **kwargs
) -> None:
"""Configure calculator and attach to system.
Parameters
----------
read_kwargs : Optional[dict[str, Any]]
read_kwargs : Optional[ASEReadArgs]
kwargs to pass to ase.io.read. Default is {}.
"""
calculator = choose_calculator(
Expand All @@ -72,44 +73,45 @@ def set_calculator(
read_kwargs = read_kwargs if read_kwargs else {}
self.read_system(**read_kwargs)

if isinstance(self.sys, list):
elif isinstance(self.sys, list):
for sys in self.sys:
sys.calc = calculator

else:
self.sys.calc = calculator

def _get_potential_energy(self) -> Union[float, list[float]]:
def _get_potential_energy(self) -> MaybeList[float]:
"""Calculate potential energy using MLIP.
Returns
-------
potential_energy : Union[float, list[float]]
potential_energy : MaybeList[float]
Potential energy of system(s).
"""
if isinstance(self.sys, list):
return [sys.get_potential_energy() for sys in self.sys]

return self.sys.get_potential_energy()

def _get_forces(self) -> Union[ndarray, list[ndarray]]:
def _get_forces(self) -> MaybeList[ndarray]:
"""Calculate forces using MLIP.
Returns
-------
forces : Union[ndarray, list[ndarray]]
forces : MaybeList[ndarray]
Forces of system(s).
"""
if isinstance(self.sys, list):
return [sys.get_forces() for sys in self.sys]

return self.sys.get_forces()

def _get_stress(self) -> Union[ndarray, list[ndarray]]:
def _get_stress(self) -> MaybeList[ndarray]:
"""Calculate stress using MLIP.
Returns
-------
stress : Union[ndarray, list[ndarray]]
stress : MaybeList[ndarray]
Stress of system(s).
"""
if isinstance(self.sys, list):
Expand All @@ -118,24 +120,22 @@ def _get_stress(self) -> Union[ndarray, list[ndarray]]:
return self.sys.get_stress()

def run_single_point(
self, properties: Optional[Union[str, list[str]]] = None
) -> dict[str, Any]:
self, properties: MaybeSequence[str] = ()
) -> CalcResults:
"""Run single point calculations.
Parameters
----------
properties : Optional[Union[str, list[str]]]
properties : MaybeSequence[str]
Physical properties to calculate. If not specified, "energy",
"forces", and "stress" will be returned.
Returns
-------
results : dict[str, Any]
results : CalcResults
Dictionary of calculated results.
"""
results = {}
if properties is None:
properties = []
results: CalcResults = {}
if isinstance(properties, str):
properties = [properties]

Expand Down

0 comments on commit aa16d61

Please sign in to comment.