From 51f8ebe3ded4cbc46d7281d9368bbf1cc0455756 Mon Sep 17 00:00:00 2001 From: ElliottKasoar Date: Thu, 22 Feb 2024 18:17:48 +0000 Subject: [PATCH] Add option to save optimization trajectory --- janus_core/geom_opt.py | 48 ++++++++++++++++++++++++++++++------------ tests/test_geom_opt.py | 48 +++++++++++++++++++++++++++++++++++------- 2 files changed, 75 insertions(+), 21 deletions(-) diff --git a/janus_core/geom_opt.py b/janus_core/geom_opt.py index c55797e4..9753f6b2 100644 --- a/janus_core/geom_opt.py +++ b/janus_core/geom_opt.py @@ -1,10 +1,10 @@ -"""Geometry optimisation.""" +"""Geometry optimization.""" from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional from ase import Atoms -from ase.io import write +from ase.io import read, write try: from ase.filters import FrechetCellFilter as DefaultFilter @@ -22,8 +22,8 @@ def optimize( filter_kwargs: Optional[dict[str, Any]] = None, optimizer: callable = LBFGS, opt_kwargs: Optional[dict[str, Any]] = None, - save_path: Optional[Union[Path, str]] = None, - save_kwargs: Optional[dict[str, Any]] = None, + struct_kwargs: Optional[dict[str, Any]] = None, + traj_kwargs: Optional[dict[str, Any]] = None, ) -> Atoms: """Optimize geometry of input structure. @@ -36,17 +36,20 @@ def optimize( dyn_kwargs : Optional[dict[str, Any]] kwargs to pass to dyn.run. Default is None. filter_func : Optional[callable] - Apply constraints to atoms through ASE filter function. Default is `FrechetCellFilter`. + Apply constraints to atoms through ASE filter function. + Default is `FrechetCellFilter`. filter_kwargs : Optional[dict[str, Any]] kwargs to pass to filter_func. Default is None. optimzer : callable ASE optimization function. Default is `LBFGS`. opt_kwargs : Optional[dict[str, Any]] kwargs to pass to optimzer. Default is None. - save_path : Optional[Union[Path, str]] - Path to save optimised structure. Default is None. - save_kwargs : Optional[dict[str, Any]] - kwargs to pass to ase.io.write. Default is None. + struct_kwargs : Optional[dict[str, Any]] + kwargs to pass to ase.io.write to save optimized structure. + Must include "filename" keyword. Default is None. + traj_kwargs : Optional[dict[str, Any]] + kwargs to pass to ase.io.write to save optimization trajectory. + Must include "filename" keyword. Default is None. Returns ------- @@ -56,7 +59,19 @@ def optimize( dyn_kwargs = dyn_kwargs if dyn_kwargs else {} filter_kwargs = filter_kwargs if filter_kwargs else {} opt_kwargs = opt_kwargs if opt_kwargs else {} - save_kwargs = save_kwargs if save_kwargs else {} + struct_kwargs = struct_kwargs if struct_kwargs else {} + traj_kwargs = traj_kwargs if traj_kwargs else {} + + if struct_kwargs and "filename" not in struct_kwargs: + raise ValueError("'filename' must be included in struct_kwargs") + + if traj_kwargs and "filename" not in traj_kwargs: + raise ValueError("'filename' must be included in traj_kwargs") + + if traj_kwargs and "trajectory" not in opt_kwargs: + raise ValueError( + "'trajectory' must be a key in opt_kwargs to save the trajectory." + ) if filter_func is not None: filtered_atoms = filter_func(atoms, **filter_kwargs) @@ -66,7 +81,14 @@ def optimize( dyn.run(fmax=fmax, **dyn_kwargs) - if save_path is not None: - write(save_path, atoms, **save_kwargs) + # Write out optimized structure + if struct_kwargs: + write(images=atoms, **struct_kwargs) + + # Reformat trajectory file from binary + if traj_kwargs: + traj = read(opt_kwargs["trajectory"], index=":") + write(images=traj, **traj_kwargs) + Path(opt_kwargs["trajectory"]).unlink() return atoms diff --git a/tests/test_geom_opt.py b/tests/test_geom_opt.py index f6e37b4b..f928e343 100644 --- a/tests/test_geom_opt.py +++ b/tests/test_geom_opt.py @@ -1,4 +1,4 @@ -"""Test geometry optimisation.""" +"""Test geometry optimization.""" from pathlib import Path @@ -68,9 +68,10 @@ def test_optimize(architecture, structure, model_path, expected, kwargs): assert atoms.get_potential_energy() == pytest.approx(expected) -def test_save(tmp_path): - """Test saving optimised structure.""" +def test_saving_struct(tmp_path): + """Test saving optimized structure.""" data_path = DATA_PATH / "NaCl.cif" + struct_path = tmp_path / "NaCl.xyz" single_point = SinglePoint( system=data_path, architecture="mace", model_paths=MODEL_PATH ) @@ -79,16 +80,15 @@ def test_save(tmp_path): optimize( single_point.sys, - save_path=tmp_path / "NaCl.xyz", - save_kwargs={"format": "extxyz"}, + struct_kwargs={"filename": struct_path, "format": "extxyz"}, ) - opt_struct = read(tmp_path / "NaCl.xyz") + opt_struct = read(struct_path) assert opt_struct.get_potential_energy() < init_energy -def test_traj(tmp_path): - """Test saving optimisation trajectory output.""" +def test_saving_traj(tmp_path): + """Test saving optimization trajectory output.""" data_path = DATA_PATH / "NaCl.cif" single_point = SinglePoint( system=data_path, architecture="mace", model_paths=MODEL_PATH @@ -96,3 +96,35 @@ def test_traj(tmp_path): optimize(single_point.sys, opt_kwargs={"trajectory": str(tmp_path / "NaCl.traj")}) traj = read(tmp_path / "NaCl.traj", index=":") assert len(traj) == 3 + + +def test_traj_reformat(tmp_path): + """Test saving optimization trajectory in different format.""" + data_path = DATA_PATH / "NaCl.cif" + traj_path_binary = tmp_path / "NaCl.traj" + traj_path_xyz = tmp_path / "NaCl-traj.xyz" + + single_point = SinglePoint( + system=data_path, architecture="mace", model_paths=MODEL_PATH + ) + + optimize( + single_point.sys, + opt_kwargs={"trajectory": str(traj_path_binary)}, + traj_kwargs={"filename": traj_path_xyz}, + ) + traj = read(tmp_path / "NaCl-traj.xyz", index=":") + + assert len(traj) == 3 + assert traj_path_binary.is_file() is False + + +def test_missing_traj_kwarg(tmp_path): + """Test saving optimization trajectory in different format.""" + data_path = DATA_PATH / "NaCl.cif" + traj_path = tmp_path / "NaCl-traj.xyz" + single_point = SinglePoint( + system=data_path, architecture="mace", model_paths=MODEL_PATH + ) + with pytest.raises(ValueError): + optimize(single_point.sys, traj_kwargs={"filename": traj_path})