Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typer #51

Merged
merged 8 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ janus\_core package
Submodules
----------

janus\_core.cli module
----------------------

.. automodule:: janus_core.cli
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.geom\_opt module
----------------------------

Expand Down
143 changes: 143 additions & 0 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Set up commandline interface."""

import ast
from typing import Annotated

import typer

from janus_core.single_point import SinglePoint

app = typer.Typer()


class TyperDict: # pylint: disable=too-few-public-methods
"""
Custom dictionary for typer.

Parameters
----------
value : str
Value of string representing a dictionary.
"""

def __init__(self, value: str):
"""
Initialise class.

Parameters
----------
value : str
Value of string representing a dictionary.
"""
self.value = value

def __str__(self):
"""
String representation of class.

Returns
-------
str
Class name and value of string representing a dictionary.
"""
return f"<TyperDict: value={self.value}>"


def parse_dict_class(value: str):
"""
Convert string input into a dictionary.

Parameters
----------
value : str
String representing dictionary to be parsed.

Returns
-------
TyperDict
Parsed string as a dictionary.
"""
return TyperDict(ast.literal_eval(value))


@app.command()
def singlepoint(
structure: Annotated[
str, typer.Option(help="Path to structure to perform calculations")
],
architecture: Annotated[
str, typer.Option("--arch", help="MLIP architecture to use for calculations")
] = "mace_mp",
device: Annotated[str, typer.Option(help="Device to run calculations on")] = "cpu",
properties: Annotated[
list[str],
typer.Option(
"--property",
help="Properties to calculate. If not specified, 'energy', 'forces', and 'stress' will be returned.",
),
] = None,
read_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
help="Keyword arguments to pass to ase.io.read [default: {}]",
metavar="DICT",
),
] = None,
calc_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
help="Keyword arguments to pass to selected calculator [default: {}]",
metavar="DICT",
),
] = None,
):
"""
Perform single point calculations.

Parameters
----------
structure : str
Structure to simulate.
architecture : Optional[str]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Optional[str]
Device to run model on. Default is "cpu".
properties : Optional[str]
Physical properties to calculate. Default is "energy".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
"""
read_kwargs = read_kwargs.value if read_kwargs else {}
calc_kwargs = calc_kwargs.value if calc_kwargs else {}

if not isinstance(read_kwargs, dict):
raise ValueError("read_kwargs must be a dictionary")
if not isinstance(calc_kwargs, dict):
raise ValueError("calc_kwargs must be a dictionary")

s_point = SinglePoint(
structure=structure,
architecture=architecture,
device=device,
read_kwargs=read_kwargs,
calc_kwargs=calc_kwargs,
)
print(s_point.run_single_point(properties=properties))


@app.command()
def test(name: str):
"""
Dummy alternative CLI command.

Parameters
----------
name : str
Name of person.
"""
print(f"Hello, {name}!")
91 changes: 48 additions & 43 deletions janus_core/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,74 +15,79 @@ class SinglePoint:

Parameters
----------
system : str
System to simulate.
structure : str
Structure to simulate.
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Literal[devices]
Device to run model on. Default is "cpu".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
**kwargs
Additional keyword arguments passed to the selected calculator.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.

Attributes
----------
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
system : str
System to simulate.
structure : str
Path of structure to simulate.
device : Literal[devices]
Device to run MLIP model on.
struct : Union[Atoms, list[Atoms]
ASE Atoms or list of Atoms structures to simulate.
structname : str
Name of structure from its filename.

Methods
-------
read_system(**kwargs)
Read system and system name.
read_structure(**kwargs)
Read structure and structure name.
set_calculator(**kwargs)
Configure calculator and attach to system.
Configure calculator and attach to structure.
run_single_point(properties=None)
Run single point calculations.
"""

def __init__(
self,
system: str,
structure: str,
architecture: Literal[architectures] = "mace_mp",
device: Literal[devices] = "cpu",
read_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
calc_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Read the system being simulated and attach an MLIP calculator.
Read the structure being simulated and attach an MLIP calculator.

Parameters
----------
system : str
System to simulate.
structure : str
Path of structure to simulate.
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Literal[devices]
Device to run MLIP model on. Default is "cpu".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
**kwargs
Additional keyword arguments passed to the selected calculator.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
"""
self.architecture = architecture
self.device = device
self.system = system
self.structure = structure

# Read system and get calculator
# Read structure and get calculator
read_kwargs = read_kwargs if read_kwargs else {}
self.read_system(**read_kwargs)
self.set_calculator(**kwargs)
calc_kwargs = calc_kwargs if calc_kwargs else {}
self.read_structure(**read_kwargs)
self.set_calculator(**calc_kwargs)

def read_system(self, **kwargs) -> None:
def read_structure(self, **kwargs) -> None:
"""
Read system and system name.
Read structure and structure name.

If the file contains multiple structures, only the last configuration
will be read by default.
Expand All @@ -92,14 +97,14 @@ def read_system(self, **kwargs) -> None:
**kwargs
Keyword arguments passed to ase.io.read.
"""
self.sys = read(self.system, **kwargs)
self.sysname = pathlib.Path(self.system).stem
self.struct = read(self.structure, **kwargs)
self.structname = pathlib.Path(self.structure).stem

def set_calculator(
self, read_kwargs: Optional[dict[str, Any]] = None, **kwargs
) -> None:
"""
Configure calculator and attach to system.
Configure calculator and attach to structure.

Parameters
----------
Expand All @@ -113,15 +118,15 @@ def set_calculator(
device=self.device,
**kwargs,
)
if self.sys is None:
if self.struct is None:
read_kwargs = read_kwargs if read_kwargs else {}
self.read_system(**read_kwargs)
self.read_structure(**read_kwargs)

if isinstance(self.sys, list):
for sys in self.sys:
sys.calc = calculator
if isinstance(self.struct, list):
for struct in self.struct:
struct.calc = calculator
else:
self.sys.calc = calculator
self.struct.calc = calculator

def _get_potential_energy(self) -> Union[float, list[float]]:
"""
Expand All @@ -130,12 +135,12 @@ def _get_potential_energy(self) -> Union[float, list[float]]:
Returns
-------
Union[float, list[float]]
Potential energy of system(s).
Potential energy of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_potential_energy() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_potential_energy() for struct in self.struct]

return self.sys.get_potential_energy()
return self.struct.get_potential_energy()

def _get_forces(self) -> Union[ndarray, list[ndarray]]:
"""
Expand All @@ -144,12 +149,12 @@ def _get_forces(self) -> Union[ndarray, list[ndarray]]:
Returns
-------
Union[ndarray, list[ndarray]]
Forces of system(s).
Forces of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_forces() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_forces() for struct in self.struct]

return self.sys.get_forces()
return self.struct.get_forces()

def _get_stress(self) -> Union[ndarray, list[ndarray]]:
"""
Expand All @@ -158,12 +163,12 @@ def _get_stress(self) -> Union[ndarray, list[ndarray]]:
Returns
-------
Union[ndarray, list[ndarray]]
Stress of system(s).
Stress of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_stress() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_stress() for struct in self.struct]

return self.sys.get_stress()
return self.struct.get_stress()

def run_single_point(
self, properties: Optional[Union[str, list[str]]] = None
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ classifiers = [
repository = "https://github.com/stfc/janus-core/"
documentation = "https://stfc.github.io/janus-core/"

[tool.poetry.scripts]
janus = "janus_core.cli:app"

[tool.poetry.dependencies]
python = "^3.9"
ase = "^3.22.1"
mace-torch = "^0.3.4"
typer = "^0.9.0"

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
Expand Down
Loading