diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index aa92a61f..3fec759a 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -125,6 +125,7 @@ class CorrelationKwargs(TypedDict, total=True): "nequip", "dpa3", "orb", + "mattersim", ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 8b67e8e9..2fff9a2b 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -284,6 +284,17 @@ def choose_calculator( calculator = ORBCalculator(model=model, device=device, **kwargs) + elif arch == "mattersim": + from mattersim import __version__ + from mattersim.forcefield import MatterSimCalculator + + if isinstance(model_path, Path): + model_path = str(model_path) + elif not isinstance(model_path, str): + model_path = "mattersim-v1.0.0-5M" + + calculator = MatterSimCalculator(load_path=model_path, device=device, **kwargs) + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index 8b2f208f..b2f4cf4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ orb = [ sevennet = [ "sevenn == 0.10.3", ] +mattersim = [ + "mattersim == 1.1.1", +] all = [ "janus-core[chgnet]", "janus-core[dpa3]", @@ -67,6 +70,7 @@ all = [ "janus-core[nequip]", "janus-core[orb]", "janus-core[sevennet]", + "janus-core[mattersim]", ] # MLIPs with dgl dependency diff --git a/tests/test_mlip_calculators.py b/tests/test_mlip_calculators.py index 693efde2..6b44a55f 100644 --- a/tests/test_mlip_calculators.py +++ b/tests/test_mlip_calculators.py @@ -74,6 +74,7 @@ ("chgnet", "cpu", {"model": CHGNET_MODEL}), ("dpa3", "cpu", {"model_path": DPA3_PATH}), ("dpa3", "cpu", {"model": DPA3_PATH}), + ("mattersim", "cpu", {}), ("nequip", "cpu", {"model_path": NEQUIP_PATH}), ("nequip", "cpu", {"model": NEQUIP_PATH}), ("orb", "cpu", {}), diff --git a/tests/test_single_point.py b/tests/test_single_point.py index cca0a0c7..975a5c18 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -70,6 +70,7 @@ def test_potential_energy( [ ("chgnet", "cpu", -29.331436157226562, "NaCl.cif", {}), ("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}), + ("mattersim", "cpu", -27.06208038330078, "NaCl.cif", {}), ( "nequip", "cpu",