Skip to content

Commit

Permalink
Add MD forces by default (#367)
Browse files Browse the repository at this point in the history
* Test MD output

* Output MD forces

* Test invalidate calc for MD

* Fix invalidating calc for MD

* Fix arch consistency

* Fix results output
  • Loading branch information
ElliottKasoar authored Dec 11, 2024
1 parent 7c7e317 commit e2b79f6
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
11 changes: 11 additions & 0 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from ase import Atoms
Expand Down Expand Up @@ -165,6 +166,16 @@ def __init__(
set_calc=set_calc,
)

# Set architecture to match calculator architecture
if isinstance(self.struct, Sequence):
if all(
image.calc and "arch" in image.calc.parameters for image in self.struct
):
self.arch = self.struct[0].calc.parameters["arch"]
else:
if self.struct.calc and "arch" in self.struct.calc.parameters:
self.arch = self.struct.calc.parameters["arch"]

FileNameMixin.__init__(
self,
self.struct,
Expand Down
16 changes: 12 additions & 4 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,6 @@ def __init__(
if self.ramp_temp and (self.temp_start < 0 or self.temp_end < 0):
raise ValueError("Start and end temperatures must be positive")

self.write_kwargs.setdefault(
"columns", ["symbols", "positions", "momenta", "masses"]
)

# Read last image by default
read_kwargs.setdefault("index", -1)

Expand Down Expand Up @@ -473,6 +469,15 @@ def __init__(
# If not specified otherwise, save optimized structure consistently with others
opt_file = self._build_filename("opt.extxyz", self.param_prefix, filename=None)

# Set defaults
default_columns = ["symbols", "positions", "momenta", "masses"]
if not write_kwargs.get("invalidate_calc", False):
default_columns.append("forces")
if "arch" in self.struct.calc.parameters and write_kwargs.get("set_info", True):
default_columns.append(f"{self.arch}_forces")

self.write_kwargs.setdefault("columns", default_columns)

if "write_kwargs" in self.minimize_kwargs:
# Use _build_filename even if given filename to ensure directory exists
self.minimize_kwargs["write_kwargs"].setdefault("filename", None)
Expand Down Expand Up @@ -1079,6 +1084,9 @@ def _run_dynamics(self) -> None:
self.temp = temp
self._set_velocity_distribution()
if isclose(temp, 0.0):
# Calculate forces and energies to be output
self.struct.get_potential_energy()
self.struct.get_forces()
self._write_final_state()
self.created_final_file = True
continue
Expand Down
18 changes: 11 additions & 7 deletions tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def test_md(ensemble):
# Check at least one image has been saved in trajectory
atoms = read(traj_path)
assert isinstance(atoms, Atoms)
assert "energy" in atoms.calc.results
assert "mace_mp_energy" in atoms.info
assert "forces" in atoms.calc.results
assert "mace_mp_forces" in atoms.arrays
assert "momenta" in atoms.arrays
assert "masses" in atoms.arrays

finally:
final_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -477,7 +483,7 @@ def test_write_kwargs(tmp_path):
traj_path = tmp_path / "md-traj.extxyz"

write_kwargs = (
"{'invalidate_calc': False, 'columns': ['symbols', 'positions', 'masses']}"
"{'invalidate_calc': True, 'columns': ['symbols', 'positions', 'masses']}"
)

result = runner.invoke(
Expand Down Expand Up @@ -509,13 +515,11 @@ def test_write_kwargs(tmp_path):
assert not final_atoms.has("momenta")
assert not traj[0].has("momenta")

# Check calculated results have been saved
assert "energy" in final_atoms.calc.results
assert "energy" in traj[0].calc.results

# Check labelled info has been set
assert "mace_mp_energy" in final_atoms.info
# Check results saved with arch label, but calc is not attached
assert final_atoms.calc is None
assert traj[0].calc is None
assert "mace_mp_energy" in traj[0].info
assert "mace_mp_energy" in final_atoms.info

assert "system_name" in final_atoms.info
assert final_atoms.info["system_name"] == "NaCl"
Expand Down

0 comments on commit e2b79f6

Please sign in to comment.