Skip to content

Commit

Permalink
Add typer (#51)
Browse files Browse the repository at this point in the history
* Make kwargs explicitly for calculator

* Add typer as dependency

* Add initial CLI commands

* Add CLI as script

* Tidy docstrings

* Add tests for singlepoint CLI

* Update docs

* Replace system with structure

---------

Co-authored-by: ElliottKasoar <ElliottKasoar@users.noreply.github.com>
  • Loading branch information
ElliottKasoar and ElliottKasoar authored Mar 5, 2024
1 parent 10c0c19 commit 4331cf2
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 106 deletions.
10 changes: 10 additions & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ janus\_core package
Submodules
----------

janus\_core.cli module
----------------------

.. automodule:: janus_core.cli
:members:
:special-members:
:private-members:
:undoc-members:
:show-inheritance:

janus\_core.geom\_opt module
----------------------------

Expand Down
143 changes: 143 additions & 0 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Set up commandline interface."""

import ast
from typing import Annotated

import typer

from janus_core.single_point import SinglePoint

app = typer.Typer()


class TyperDict: # pylint: disable=too-few-public-methods
"""
Custom dictionary for typer.
Parameters
----------
value : str
Value of string representing a dictionary.
"""

def __init__(self, value: str):
"""
Initialise class.
Parameters
----------
value : str
Value of string representing a dictionary.
"""
self.value = value

def __str__(self):
"""
String representation of class.
Returns
-------
str
Class name and value of string representing a dictionary.
"""
return f"<TyperDict: value={self.value}>"


def parse_dict_class(value: str):
"""
Convert string input into a dictionary.
Parameters
----------
value : str
String representing dictionary to be parsed.
Returns
-------
TyperDict
Parsed string as a dictionary.
"""
return TyperDict(ast.literal_eval(value))


@app.command()
def singlepoint(
structure: Annotated[
str, typer.Option(help="Path to structure to perform calculations")
],
architecture: Annotated[
str, typer.Option("--arch", help="MLIP architecture to use for calculations")
] = "mace_mp",
device: Annotated[str, typer.Option(help="Device to run calculations on")] = "cpu",
properties: Annotated[
list[str],
typer.Option(
"--property",
help="Properties to calculate. If not specified, 'energy', 'forces', and 'stress' will be returned.",
),
] = None,
read_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
help="Keyword arguments to pass to ase.io.read [default: {}]",
metavar="DICT",
),
] = None,
calc_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
help="Keyword arguments to pass to selected calculator [default: {}]",
metavar="DICT",
),
] = None,
):
"""
Perform single point calculations.
Parameters
----------
structure : str
Structure to simulate.
architecture : Optional[str]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Optional[str]
Device to run model on. Default is "cpu".
properties : Optional[str]
Physical properties to calculate. Default is "energy".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
"""
read_kwargs = read_kwargs.value if read_kwargs else {}
calc_kwargs = calc_kwargs.value if calc_kwargs else {}

if not isinstance(read_kwargs, dict):
raise ValueError("read_kwargs must be a dictionary")
if not isinstance(calc_kwargs, dict):
raise ValueError("calc_kwargs must be a dictionary")

s_point = SinglePoint(
structure=structure,
architecture=architecture,
device=device,
read_kwargs=read_kwargs,
calc_kwargs=calc_kwargs,
)
print(s_point.run_single_point(properties=properties))


@app.command()
def test(name: str):
"""
Dummy alternative CLI command.
Parameters
----------
name : str
Name of person.
"""
print(f"Hello, {name}!")
91 changes: 48 additions & 43 deletions janus_core/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,74 +15,79 @@ class SinglePoint:
Parameters
----------
system : str
System to simulate.
structure : str
Structure to simulate.
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Literal[devices]
Device to run model on. Default is "cpu".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
**kwargs
Additional keyword arguments passed to the selected calculator.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
Attributes
----------
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
system : str
System to simulate.
structure : str
Path of structure to simulate.
device : Literal[devices]
Device to run MLIP model on.
struct : Union[Atoms, list[Atoms]
ASE Atoms or list of Atoms structures to simulate.
structname : str
Name of structure from its filename.
Methods
-------
read_system(**kwargs)
Read system and system name.
read_structure(**kwargs)
Read structure and structure name.
set_calculator(**kwargs)
Configure calculator and attach to system.
Configure calculator and attach to structure.
run_single_point(properties=None)
Run single point calculations.
"""

def __init__(
self,
system: str,
structure: str,
architecture: Literal[architectures] = "mace_mp",
device: Literal[devices] = "cpu",
read_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
calc_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Read the system being simulated and attach an MLIP calculator.
Read the structure being simulated and attach an MLIP calculator.
Parameters
----------
system : str
System to simulate.
structure : str
Path of structure to simulate.
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Literal[devices]
Device to run MLIP model on. Default is "cpu".
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
**kwargs
Additional keyword arguments passed to the selected calculator.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
"""
self.architecture = architecture
self.device = device
self.system = system
self.structure = structure

# Read system and get calculator
# Read structure and get calculator
read_kwargs = read_kwargs if read_kwargs else {}
self.read_system(**read_kwargs)
self.set_calculator(**kwargs)
calc_kwargs = calc_kwargs if calc_kwargs else {}
self.read_structure(**read_kwargs)
self.set_calculator(**calc_kwargs)

def read_system(self, **kwargs) -> None:
def read_structure(self, **kwargs) -> None:
"""
Read system and system name.
Read structure and structure name.
If the file contains multiple structures, only the last configuration
will be read by default.
Expand All @@ -92,14 +97,14 @@ def read_system(self, **kwargs) -> None:
**kwargs
Keyword arguments passed to ase.io.read.
"""
self.sys = read(self.system, **kwargs)
self.sysname = pathlib.Path(self.system).stem
self.struct = read(self.structure, **kwargs)
self.structname = pathlib.Path(self.structure).stem

def set_calculator(
self, read_kwargs: Optional[dict[str, Any]] = None, **kwargs
) -> None:
"""
Configure calculator and attach to system.
Configure calculator and attach to structure.
Parameters
----------
Expand All @@ -113,15 +118,15 @@ def set_calculator(
device=self.device,
**kwargs,
)
if self.sys is None:
if self.struct is None:
read_kwargs = read_kwargs if read_kwargs else {}
self.read_system(**read_kwargs)
self.read_structure(**read_kwargs)

if isinstance(self.sys, list):
for sys in self.sys:
sys.calc = calculator
if isinstance(self.struct, list):
for struct in self.struct:
struct.calc = calculator
else:
self.sys.calc = calculator
self.struct.calc = calculator

def _get_potential_energy(self) -> Union[float, list[float]]:
"""
Expand All @@ -130,12 +135,12 @@ def _get_potential_energy(self) -> Union[float, list[float]]:
Returns
-------
Union[float, list[float]]
Potential energy of system(s).
Potential energy of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_potential_energy() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_potential_energy() for struct in self.struct]

return self.sys.get_potential_energy()
return self.struct.get_potential_energy()

def _get_forces(self) -> Union[ndarray, list[ndarray]]:
"""
Expand All @@ -144,12 +149,12 @@ def _get_forces(self) -> Union[ndarray, list[ndarray]]:
Returns
-------
Union[ndarray, list[ndarray]]
Forces of system(s).
Forces of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_forces() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_forces() for struct in self.struct]

return self.sys.get_forces()
return self.struct.get_forces()

def _get_stress(self) -> Union[ndarray, list[ndarray]]:
"""
Expand All @@ -158,12 +163,12 @@ def _get_stress(self) -> Union[ndarray, list[ndarray]]:
Returns
-------
Union[ndarray, list[ndarray]]
Stress of system(s).
Stress of structure(s).
"""
if isinstance(self.sys, list):
return [sys.get_stress() for sys in self.sys]
if isinstance(self.struct, list):
return [struct.get_stress() for struct in self.struct]

return self.sys.get_stress()
return self.struct.get_stress()

def run_single_point(
self, properties: Optional[Union[str, list[str]]] = None
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ classifiers = [
repository = "https://github.com/stfc/janus-core/"
documentation = "https://stfc.github.io/janus-core/"

[tool.poetry.scripts]
janus = "janus_core.cli:app"

[tool.poetry.dependencies]
python = "^3.9"
ase = "^3.22.1"
mace-torch = "^0.3.4"
typer = "^0.9.0"

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
Expand Down
Loading

0 comments on commit 4331cf2

Please sign in to comment.