From 8cb1d3b99105a5e8e9230a46cf2698abcb7821a1 Mon Sep 17 00:00:00 2001 From: Yuan Chiang Date: Tue, 14 Jan 2025 13:26:18 -0800 Subject: [PATCH] propagate device setting to calculator --- mlip_arena/tasks/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlip_arena/tasks/utils.py b/mlip_arena/tasks/utils.py index 71e27c6..da920f9 100644 --- a/mlip_arena/tasks/utils.py +++ b/mlip_arena/tasks/utils.py @@ -28,11 +28,13 @@ def get_calculator( device: str | None = None, ) -> Calculator | SumCalculator: """Get a calculator with optional dispersion correction.""" - device = device or str(get_freer_device()) - logger.info(f"Using device: {device}") + device = device or str(get_freer_device()) calculator_kwargs = calculator_kwargs or {} + calculator_kwargs.update({"device": device}) + + logger.info(f"Using device: {device}") if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum: calc = calculator_name.value(**calculator_kwargs)