diff --git a/mlip_arena/tasks/utils.py b/mlip_arena/tasks/utils.py index 71e27c6..da920f9 100644 --- a/mlip_arena/tasks/utils.py +++ b/mlip_arena/tasks/utils.py @@ -28,11 +28,13 @@ def get_calculator( device: str | None = None, ) -> Calculator | SumCalculator: """Get a calculator with optional dispersion correction.""" - device = device or str(get_freer_device()) - logger.info(f"Using device: {device}") + device = device or str(get_freer_device()) calculator_kwargs = calculator_kwargs or {} + calculator_kwargs.update({"device": device}) + + logger.info(f"Using device: {device}") if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum: calc = calculator_name.value(**calculator_kwargs)