Skip to content

Commit

Permalink
Convert yaml configuration hyphens to underscores (#113)
Browse files Browse the repository at this point in the history
Co-authored-by: ElliottKasoar <ElliottKasoar@users.noreply.github.com>
  • Loading branch information
ElliottKasoar and ElliottKasoar authored Apr 11, 2024
1 parent 9f9ebc1 commit f4f9d01
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 31 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ properties:
- "energy"
out: "NaCl-results.xyz"
arch: mace_mp
calc_kwargs:
calc-kwargs:
model: medium
```
Expand All @@ -187,10 +187,6 @@ This will run a singlepoint energy calculation on `KCl.cif` using the [MACE-MP](
> [!NOTE]
> `properties` must be passed as a Yaml list, as above, not as a string.
> [!WARNING]
> Options in the Yaml file must use `_` instead of `-`.
> For example, `calc_kwargs` should be used in the configuration file for the `--calc-kwargs` option.
> [!WARNING]
> If an option in the configuration file does not match any variable names, an error will **not** be raised.
> Please check the summary file to ensure the configuration has been read correctly.
Expand Down
43 changes: 26 additions & 17 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import datetime
import logging
from pathlib import Path
from typing import Annotated, Optional, Union, get_args
from typing import Annotated, Any, Optional, Union, get_args

from ase import Atoms
import typer
from typer_config import use_yaml_config
from typer_config import conf_callback_factory, use_config, yaml_loader
import yaml

from janus_core import __version__
from janus_core.geom_opt import optimize
from janus_core.janus_types import ASEReadArgs, Ensembles
from janus_core.md import NPH, NPT, NVE, NVT, NVT_NH
from janus_core.single_point import SinglePoint
from janus_core.utils import dict_paths_to_strs, dict_remove_hyphens

app = typer.Typer(name="janus", no_args_is_help=True)

Expand Down Expand Up @@ -101,21 +102,29 @@ def _parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
return typer_dicts


def _dict_paths_to_strs(dictionary: dict) -> None:
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
"""
Recursively iterate over dictionary, converting Path values to strings.
Load yaml configuration and replace hyphens with underscores.
Parameters
----------
dictionary : dict
Dictionary to be converted.
config_file : str
Yaml configuration file to read.
Returns
-------
dict[str, Any]
Dictionary with loaded configuration.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
_dict_paths_to_strs(value)
elif isinstance(value, Path):
dictionary[key] = str(value)
if not config_file:
return {}

config = yaml_loader(config_file)
# Replace all "-"" with "_" in conf
return dict_remove_hyphens(config)


yaml_converter_callback = conf_callback_factory(yaml_converter_loader)

# Shared type aliases
StructPath = Annotated[Path, typer.Option(help="Path of structure to simulate.")]
Expand Down Expand Up @@ -190,7 +199,7 @@ def print_version(


@app.command(help="Perform single point calculations and save to file.")
@use_yaml_config()
@use_config(yaml_converter_callback)
def singlepoint(
# pylint: disable=too-many-locals
# numpydoc ignore=PR02
Expand Down Expand Up @@ -308,7 +317,7 @@ def singlepoint(
}

# Convert all paths to strings in inputs nested dictionary
_dict_paths_to_strs(inputs)
dict_paths_to_strs(inputs)

# Save summary information before singlepoint calculation begins
save_info = {
Expand All @@ -335,7 +344,7 @@ def singlepoint(
@app.command(
help="Perform geometry optimization and save optimized structure to file.",
)
@use_yaml_config()
@use_config(yaml_converter_callback)
def geomopt(
# pylint: disable=too-many-arguments,too-many-locals
# numpydoc ignore=PR02
Expand Down Expand Up @@ -509,7 +518,7 @@ def geomopt(
}

# Convert all paths to strings in inputs nested dictionary
_dict_paths_to_strs(inputs)
dict_paths_to_strs(inputs)

# Save summary information before optimization begins
save_info = {
Expand All @@ -536,7 +545,7 @@ def geomopt(
@app.command(
help="Run molecular dynamics simulation, and save trajectory and statistics.",
)
@use_yaml_config()
@use_config(yaml_converter_callback)
def md(
# pylint: disable=too-many-arguments,too-many-locals,invalid-name
# numpydoc ignore=PR02
Expand Down Expand Up @@ -895,7 +904,7 @@ def md(
}

# Convert all paths to strings in inputs nested dictionary
_dict_paths_to_strs(inputs)
dict_paths_to_strs(inputs)

# Save summary information before simulation begins
save_info = {
Expand Down
4 changes: 2 additions & 2 deletions janus_core/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _reset_velocities(self) -> None:

def _optimize_structure(self) -> None:
"""Perform geometry optimization."""
if self.dyn.nsteps < self.equil_steps:
if self.dyn.nsteps == 0 or self.dyn.nsteps < self.equil_steps:
if self.logger:
self.minimize_kwargs["log_kwargs"] = {
"filename": self.log_kwargs["filename"],
Expand Down Expand Up @@ -502,7 +502,7 @@ def run(self) -> None:

else:
if self.minimize:
optimize(self.struct, **self.minimize_kwargs)
self._optimize_structure()
if self.rescale_velocities:
self._reset_velocities()

Expand Down
37 changes: 37 additions & 0 deletions janus_core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for janus_core."""

from pathlib import Path
from typing import Optional


Expand All @@ -20,3 +21,39 @@ def none_to_dict(dictionaries: list[Optional[dict]]) -> list[dict]:
for i, dictionary in enumerate(dictionaries):
dictionaries[i] = dictionary if dictionary else {}
return dictionaries


def dict_paths_to_strs(dictionary: dict) -> None:
"""
Recursively iterate over dictionary, converting Path values to strings.
Parameters
----------
dictionary : dict
Dictionary to be converted.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
dict_paths_to_strs(value)
elif isinstance(value, Path):
dictionary[key] = str(value)


def dict_remove_hyphens(dictionary: dict) -> dict:
"""
Recursively iterate over dictionary, replacing hyphens with underscores in keys.
Parameters
----------
dictionary : dict
Dictionary to be converted.
Returns
-------
dict
Dictionary with hyphens in keys replaced with underscores.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
dictionary[key] = dict_remove_hyphens(value)
return {k.replace("-", "_"): v for k, v in dictionary.items()}
6 changes: 3 additions & 3 deletions tests/data/geomopt_config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct_path: "NaCl.cif"
out_file: "NaCl-results.xyz"
opt_kwargs:
struct: "NaCl.cif"
out: "NaCl-results.xyz"
opt-kwargs:
alpha: 100
3 changes: 3 additions & 0 deletions tests/data/md_config.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
ensemble: "nvt"
temp: 200
minimize-kwargs:
filter-kwargs:
hydrostatic-strain: True
8 changes: 4 additions & 4 deletions tests/data/singlepoint_config.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
struct_path: "NaCl.cif"
struct: "NaCl.cif"
properties:
- "energy"
out_file: "NaCl-results.xyz"
calc_kwargs:
out: "NaCl-results.xyz"
calc-kwargs:
model: "small"
read_kwargs:
read-kwargs:
index: ":"
9 changes: 9 additions & 0 deletions tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def test_config(tmp_path):
file_prefix,
"--steps",
2,
"--minimize",
"--log",
log_path,
"--summary",
Expand All @@ -278,6 +279,14 @@ def test_config(tmp_path):
assert md_summary["inputs"]["temp"] == 200
# Check explicit option overwrites config
assert md_summary["inputs"]["ensemble"] == "nve"
# Check nested dictionary
assert (
md_summary["inputs"]["minimize_kwargs"]["filter_kwargs"]["hydrostatic_strain"]
is True
)

# Check hydrostatic strain passed correctly
assert_log_contains(log_path, includes=["hydrostatic_strain: True"])


def test_heating(tmp_path):
Expand Down
48 changes: 48 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Test utility functions."""

from pathlib import Path

from janus_core.utils import dict_paths_to_strs, dict_remove_hyphens


def test_dict_paths_to_strs():
"""Test Paths are converted to strings."""
dictionary = {
"key1": Path("/example/path"),
"key2": {
"key3": Path("another/example"),
"key4": "example",
},
}

# Check Paths are present
assert isinstance(dictionary["key1"], Path)
assert isinstance(dictionary["key2"]["key3"], Path)
assert not isinstance(dictionary["key1"], str)

dict_paths_to_strs(dictionary)

# Check Paths are now strings
assert isinstance(dictionary["key1"], str)
assert isinstance(dictionary["key2"]["key3"], str)


def test_dict_remove_hyphens():
"""Test hyphens are replaced with underscores."""
dictionary = {
"key-1": "value_1",
"key-2": {
"key-3": "value-3",
"key-4": 4,
"key_5": 5.0,
"key6": {"key-7": "value7"},
},
}
dictionary = dict_remove_hyphens(dictionary)

# Check hyphens are now strings
assert dictionary["key_1"] == "value_1"
assert dictionary["key_2"]["key_3"] == "value-3"
assert dictionary["key_2"]["key_4"] == 4
assert dictionary["key_2"]["key_5"] == 5.0
assert dictionary["key_2"]["key6"]["key_7"] == "value7"

0 comments on commit f4f9d01

Please sign in to comment.