From ba0f18aafa671a3b253786d2b4a9c5b96a2f74c9 Mon Sep 17 00:00:00 2001 From: alin elena Date: Wed, 4 Sep 2024 11:18:02 +0100 Subject: [PATCH 1/4] add orb support... tricky --- janus_core/helpers/janus_types.py | 2 +- janus_core/helpers/mlip_calculators.py | 30 ++++++++++++++++++++++++++ pyproject.toml | 8 +++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 27c6c556..0f8c878b 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -115,7 +115,7 @@ class CorrelationKwargs(TypedDict, total=True): # Janus specific Architectures = Literal[ - "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet" + "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "orb" ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 537dc027..6780ee4e 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -225,6 +225,36 @@ def choose_calculator( kwargs.setdefault("sevennet_config", None) calculator = SevenNetCalculator(model=model_path, device=device, **kwargs) + elif arch == "orb": + __version__ = "0.3" + from orb_models.forcefield.calculator import ORBCalculator + from orb_models.forcefield.graph_regressor import GraphRegressor + import orb_models.forcefield.pretrained as orb_ff + + if isinstance(model_path, str): + match model_path: + case "orb-v1": + model = orb_ff.orb_v1() + case "orb-mptraj-only-v1": + model = orb_ff.orb_v1_mptraj_only() + case "orb-d3-v1": + model = orb_ff.orb_d3_v1() + case "orb-d3-xs-v1": + model = orb_ff.orb_d3_xs_v1() + case "orb-d3-sm-v1": + model = orb_ff.orb_d3_sm_v1() + case _: + raise ValueError( + "Please specify `model_path`, as there is no " + f"default model for {arch}" + ) + elif isinstance(model_path, GraphRegressor): + model = model_path + else: + model = orb_ff.orb_v1_mptraj_only() + + calculator = ORBCalculator(model=model, device=device, **kwargs) + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index a481a7d9..2097bb14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,10 @@ m3gnet = [ "matgl == 1.1.3", "dgl == 2.1.0", ] +orb = [ + "orb-models == 0.4.1", + "pynanoflann", +] sevennet = [ "sevenn == 0.10.0", ] @@ -55,6 +59,7 @@ all = [ "janus-core[alignn]", "janus-core[chgnet]", "janus-core[m3gnet]", + "janus-core[orb]", "janus-core[sevennet]", ] @@ -164,3 +169,6 @@ default-groups = [ "docs", "pre-commit", ] + +[tool.uv.sources] +pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" } From 061e7ed5f41e204d15f66cd4de9dcae1eb255a52 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 2 Jan 2025 15:51:37 +0000 Subject: [PATCH 2/4] Generalise orb version --- janus_core/helpers/mlip_calculators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 6780ee4e..a1310c2c 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -226,7 +226,7 @@ def choose_calculator( calculator = SevenNetCalculator(model=model_path, device=device, **kwargs) elif arch == "orb": - __version__ = "0.3" + from orb_models import __version__ from orb_models.forcefield.calculator import ORBCalculator from orb_models.forcefield.graph_regressor import GraphRegressor import orb_models.forcefield.pretrained as orb_ff From 14fab378bd0722676dab427a621aa08c3013f6a5 Mon Sep 17 00:00:00 2001 From: alin m elena Date: Fri, 14 Feb 2025 13:41:04 +0000 Subject: [PATCH 3/4] calculators --- janus_core/helpers/mlip_calculators.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 903fe683..a41a857c 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -232,16 +232,16 @@ def choose_calculator( if isinstance(model_path, str): match model_path: - case "orb-v1": - model = orb_ff.orb_v1() - case "orb-mptraj-only-v1": - model = orb_ff.orb_v1_mptraj_only() - case "orb-d3-v1": - model = orb_ff.orb_d3_v1() - case "orb-d3-xs-v1": - model = orb_ff.orb_d3_xs_v1() - case "orb-d3-sm-v1": - model = orb_ff.orb_d3_sm_v1() + case "orb-v2": + model = orb_ff.orb_v2() + case "orb-mptraj-only-v2": + model = orb_ff.orb_v2_mptraj_only() + case "orb-d3-v2": + model = orb_ff.orb_d3_v2() + case "orb-d3-xs-v2": + model = orb_ff.orb_d3_xs_v2() + case "orb-d3-sm-v2": + model = orb_ff.orb_d3_sm_v2() case _: raise ValueError( "Please specify `model_path`, as there is no " From 6f9d83666c5b3a24f407c6e119acd43d9e6a5780 Mon Sep 17 00:00:00 2001 From: alin m elena Date: Fri, 14 Feb 2025 13:51:48 +0000 Subject: [PATCH 4/4] add tests --- janus_core/helpers/mlip_calculators.py | 2 +- tests/test_single_point.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index a41a857c..cd787e84 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -250,7 +250,7 @@ def choose_calculator( elif isinstance(model_path, GraphRegressor): model = model_path else: - model = orb_ff.orb_v1_mptraj_only() + model = orb_ff.orb_v2() calculator = ORBCalculator(model=model, device=device, **kwargs) diff --git a/tests/test_single_point.py b/tests/test_single_point.py index 1b86c3c6..a76552fd 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -284,6 +284,8 @@ def test_mlips(arch, device, expected_energy): ("sevennet", "cpu", -27.061979293823242, {"model_path": SEVENNET_PATH}), ("sevennet", "cpu", -27.061979293823242, {}), ("sevennet", "cpu", -27.061979293823242, {"model_path": "SevenNet-0_11July2024"}), + ("orb", "cpu", -27.088973999023438, {}), + ("orb", "cpu", -27.088973999023438, {"model_path": "orb-v2"}), ]