diff --git a/tests/test_md.py b/tests/test_md.py index 18aa1fdd..a9dba54a 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -18,6 +18,15 @@ DATA_PATH = Path(__file__).parent / "data" MODEL_PATH = Path(__file__).parent / "models" / "mace_mp_small.model" +try: + from ase.md.bussi import Bussi # noqa: F401 + + from janus_core.calculations.md import NVT_Bussi + + ASE_IMPORT_ERROR = False +except ImportError: + ASE_IMPORT_ERROR = True + test_data = [ (NVT, "nvt"), (NVE, "nve"), @@ -226,6 +235,57 @@ def test_nph(): stats_path.unlink(missing_ok=True) +@pytest.mark.skipif(ASE_IMPORT_ERROR, reason="Requires updated version of ASE") +def test_nvt_bussi(): + """Test Bussi molecular dynamics.""" + restart_path_1 = Path("NaCl-nvt-bussi-T300.0-res-2.extxyz") + restart_path_2 = Path("NaCl-nvt-bussi-T300.0-res-4.extxyz") + restart_final = Path("NaCl-nvt-bussi-T300.0-final.extxyz") + traj_path = Path("NaCl-nvt-bussi-T300.0-traj.extxyz") + stats_path = Path("NaCl-nvt-bussi-T300.0-stats.dat") + + assert not restart_path_1.exists() + assert not restart_path_2.exists() + assert not restart_final.exists() + assert not traj_path.exists() + assert not stats_path.exists() + + bussi = NVT_Bussi( + struct_path=DATA_PATH / "NaCl.cif", + arch="mace", + model_path=MODEL_PATH, + temp=300.0, + steps=4, + traj_every=1, + restart_every=2, + stats_every=1, + taut=10, + ) + + try: + bussi.run() + restart_atoms_1 = read(restart_path_1) + assert isinstance(restart_atoms_1, Atoms) + restart_atoms_2 = read(restart_path_2) + assert isinstance(restart_atoms_2, Atoms) + restart_atoms_final = read(restart_final) + assert isinstance(restart_atoms_final, Atoms) + traj = read(traj_path, index=":") + assert all(isinstance(image, Atoms) for image in traj) + assert len(traj) == 5 + + with open(stats_path, encoding="utf8") as stats_file: + lines = stats_file.readlines() + assert "Target_T [K]" in lines[0] + assert len(lines) == 6 + finally: + restart_path_1.unlink(missing_ok=True) + restart_path_2.unlink(missing_ok=True) + restart_final.unlink(missing_ok=True) + traj_path.unlink(missing_ok=True) + stats_path.unlink(missing_ok=True) + + def test_restart(tmp_path): """Test restarting molecular dynamics simulation.""" file_prefix = tmp_path / "Cl4Na4-nvt-T300.0"