Skip to content

Commit

Permalink
add mattersim calculator fixes stfc#425
Browse files Browse the repository at this point in the history
  • Loading branch information
alinelena authored and ElliottKasoar committed Feb 25, 2025
1 parent 09dff33 commit e9646ff
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class CorrelationKwargs(TypedDict, total=True):
"nequip",
"dpa3",
"orb",
"mattersim",
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]
Expand Down
11 changes: 11 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,17 @@ def choose_calculator(

calculator = ORBCalculator(model=model, device=device, **kwargs)

elif arch == "mattersim":
from mattersim import __version__
from mattersim.forcefield import MatterSimCalculator

if isinstance(model_path, Path):
model_path = str(model_path)
elif not isinstance(model_path, str):
model_path = "mattersim-v1.0.0-5M"

calculator = MatterSimCalculator(load_path=model_path, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {arch=}. Suported architectures "
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,17 @@ orb = [
sevennet = [
"sevenn == 0.10.3",
]
mattersim = [
"mattersim == 1.1.1",
]
all = [
"janus-core[chgnet]",
"janus-core[dpa3]",
"janus-core[mace]",
"janus-core[nequip]",
"janus-core[orb]",
"janus-core[sevennet]",
"janus-core[mattersim]",
]

# MLIPs with dgl dependency
Expand Down
1 change: 1 addition & 0 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
("chgnet", "cpu", {"model": CHGNET_MODEL}),
("dpa3", "cpu", {"model_path": DPA3_PATH}),
("dpa3", "cpu", {"model": DPA3_PATH}),
("mattersim", "cpu", {}),
("nequip", "cpu", {"model_path": NEQUIP_PATH}),
("nequip", "cpu", {"model": NEQUIP_PATH}),
("orb", "cpu", {}),
Expand Down
1 change: 1 addition & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_potential_energy(
[
("chgnet", "cpu", -29.331436157226562, "NaCl.cif", {}),
("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}),
("mattersim", "cpu", -27.06208038330078, "NaCl.cif", {}),
(
"nequip",
"cpu",
Expand Down

0 comments on commit e9646ff

Please sign in to comment.