diff --git a/janus_core/mlip_calculators.py b/janus_core/mlip_calculators.py index 2ef4b3b1..f6a3fbff 100644 --- a/janus_core/mlip_calculators.py +++ b/janus_core/mlip_calculators.py @@ -58,7 +58,7 @@ def choose_calculator( from mace.calculators import mace_mp kwargs.setdefault("default_dtype", "float64") - kwargs.setdefault("model", "small") + kwargs["model"] = kwargs.pop("model_paths", "small") calculator = mace_mp(**kwargs) elif architecture == "mace_off": @@ -66,7 +66,7 @@ def choose_calculator( from mace.calculators import mace_off kwargs.setdefault("default_dtype", "float64") - kwargs.setdefault("model", "small") + kwargs["model"] = kwargs.pop("model_paths", "small") calculator = mace_off(**kwargs) elif architecture == "m3gnet": diff --git a/tests/test_mlip_calculators.py b/tests/test_mlip_calculators.py index d07e0f4c..cc06c5bc 100644 --- a/tests/test_mlip_calculators.py +++ b/tests/test_mlip_calculators.py @@ -14,6 +14,16 @@ ), ("mace_off", "cpu", {}), ("mace_mp", "cpu", {}), + ( + "mace_mp", + "cpu", + {"model_paths": Path(__file__).parent / "models" / "mace_mp_small.model"}, + ), + ( + "mace_off", + "cpu", + {"model_paths": "small"}, + ), ] test_data_extras = [("m3gnet", "cpu"), ("chgnet", "")]