Skip to content

Commit

Permalink
Add option to save optimization trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Feb 22, 2024
1 parent 9b11f9b commit 51f8ebe
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
48 changes: 35 additions & 13 deletions janus_core/geom_opt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Geometry optimisation."""
"""Geometry optimization."""

from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional

from ase import Atoms
from ase.io import write
from ase.io import read, write

try:
from ase.filters import FrechetCellFilter as DefaultFilter
Expand All @@ -22,8 +22,8 @@ def optimize(
filter_kwargs: Optional[dict[str, Any]] = None,
optimizer: callable = LBFGS,
opt_kwargs: Optional[dict[str, Any]] = None,
save_path: Optional[Union[Path, str]] = None,
save_kwargs: Optional[dict[str, Any]] = None,
struct_kwargs: Optional[dict[str, Any]] = None,
traj_kwargs: Optional[dict[str, Any]] = None,
) -> Atoms:
"""Optimize geometry of input structure.
Expand All @@ -36,17 +36,20 @@ def optimize(
dyn_kwargs : Optional[dict[str, Any]]
kwargs to pass to dyn.run. Default is None.
filter_func : Optional[callable]
Apply constraints to atoms through ASE filter function. Default is `FrechetCellFilter`.
Apply constraints to atoms through ASE filter function.
Default is `FrechetCellFilter`.
filter_kwargs : Optional[dict[str, Any]]
kwargs to pass to filter_func. Default is None.
optimzer : callable
ASE optimization function. Default is `LBFGS`.
opt_kwargs : Optional[dict[str, Any]]
kwargs to pass to optimzer. Default is None.
save_path : Optional[Union[Path, str]]
Path to save optimised structure. Default is None.
save_kwargs : Optional[dict[str, Any]]
kwargs to pass to ase.io.write. Default is None.
struct_kwargs : Optional[dict[str, Any]]
kwargs to pass to ase.io.write to save optimized structure.
Must include "filename" keyword. Default is None.
traj_kwargs : Optional[dict[str, Any]]
kwargs to pass to ase.io.write to save optimization trajectory.
Must include "filename" keyword. Default is None.
Returns
-------
Expand All @@ -56,7 +59,19 @@ def optimize(
dyn_kwargs = dyn_kwargs if dyn_kwargs else {}
filter_kwargs = filter_kwargs if filter_kwargs else {}
opt_kwargs = opt_kwargs if opt_kwargs else {}
save_kwargs = save_kwargs if save_kwargs else {}
struct_kwargs = struct_kwargs if struct_kwargs else {}
traj_kwargs = traj_kwargs if traj_kwargs else {}

if struct_kwargs and "filename" not in struct_kwargs:
raise ValueError("'filename' must be included in struct_kwargs")

if traj_kwargs and "filename" not in traj_kwargs:
raise ValueError("'filename' must be included in traj_kwargs")

if traj_kwargs and "trajectory" not in opt_kwargs:
raise ValueError(
"'trajectory' must be a key in opt_kwargs to save the trajectory."
)

if filter_func is not None:
filtered_atoms = filter_func(atoms, **filter_kwargs)
Expand All @@ -66,7 +81,14 @@ def optimize(

dyn.run(fmax=fmax, **dyn_kwargs)

if save_path is not None:
write(save_path, atoms, **save_kwargs)
# Write out optimized structure
if struct_kwargs:
write(images=atoms, **struct_kwargs)

# Reformat trajectory file from binary
if traj_kwargs:
traj = read(opt_kwargs["trajectory"], index=":")
write(images=traj, **traj_kwargs)
Path(opt_kwargs["trajectory"]).unlink()

return atoms
48 changes: 40 additions & 8 deletions tests/test_geom_opt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Test geometry optimisation."""
"""Test geometry optimization."""

from pathlib import Path

Expand Down Expand Up @@ -68,9 +68,10 @@ def test_optimize(architecture, structure, model_path, expected, kwargs):
assert atoms.get_potential_energy() == pytest.approx(expected)


def test_save(tmp_path):
"""Test saving optimised structure."""
def test_saving_struct(tmp_path):
"""Test saving optimized structure."""
data_path = DATA_PATH / "NaCl.cif"
struct_path = tmp_path / "NaCl.xyz"
single_point = SinglePoint(
system=data_path, architecture="mace", model_paths=MODEL_PATH
)
Expand All @@ -79,20 +80,51 @@ def test_save(tmp_path):

optimize(
single_point.sys,
save_path=tmp_path / "NaCl.xyz",
save_kwargs={"format": "extxyz"},
struct_kwargs={"filename": struct_path, "format": "extxyz"},
)
opt_struct = read(tmp_path / "NaCl.xyz")
opt_struct = read(struct_path)

assert opt_struct.get_potential_energy() < init_energy


def test_traj(tmp_path):
"""Test saving optimisation trajectory output."""
def test_saving_traj(tmp_path):
"""Test saving optimization trajectory output."""
data_path = DATA_PATH / "NaCl.cif"
single_point = SinglePoint(
system=data_path, architecture="mace", model_paths=MODEL_PATH
)
optimize(single_point.sys, opt_kwargs={"trajectory": str(tmp_path / "NaCl.traj")})
traj = read(tmp_path / "NaCl.traj", index=":")
assert len(traj) == 3


def test_traj_reformat(tmp_path):
"""Test saving optimization trajectory in different format."""
data_path = DATA_PATH / "NaCl.cif"
traj_path_binary = tmp_path / "NaCl.traj"
traj_path_xyz = tmp_path / "NaCl-traj.xyz"

single_point = SinglePoint(
system=data_path, architecture="mace", model_paths=MODEL_PATH
)

optimize(
single_point.sys,
opt_kwargs={"trajectory": str(traj_path_binary)},
traj_kwargs={"filename": traj_path_xyz},
)
traj = read(tmp_path / "NaCl-traj.xyz", index=":")

assert len(traj) == 3
assert traj_path_binary.is_file() is False


def test_missing_traj_kwarg(tmp_path):
"""Test saving optimization trajectory in different format."""
data_path = DATA_PATH / "NaCl.cif"
traj_path = tmp_path / "NaCl-traj.xyz"
single_point = SinglePoint(
system=data_path, architecture="mace", model_paths=MODEL_PATH
)
with pytest.raises(ValueError):
optimize(single_point.sys, traj_kwargs={"filename": traj_path})

0 comments on commit 51f8ebe

Please sign in to comment.