diff --git a/mlip_arena/models/externals/mattersim.py b/mlip_arena/models/externals/mattersim.py
index f1a2d52..dab0660 100644
--- a/mlip_arena/models/externals/mattersim.py
+++ b/mlip_arena/models/externals/mattersim.py
@@ -24,13 +24,21 @@ def __init__(
             load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
         )
 
-    def calculate(
-        self,
-        atoms: Atoms | None = None,
-        properties: list | None = None,
-        system_changes: list | None = None,
-    ):
-        super().calculate(atoms, properties, system_changes)
+    def __getstate__(self):
+        state = self.__dict__.copy()
+
+        # HACK: drop the unpicklable potential from the pickled state
+        state.pop("potential", None)
+
+        return state
+
+    # def calculate(
+    #     self,
+    #     atoms: Atoms | None = None,
+    #     properties: list | None = None,
+    #     system_changes: list | None = None,
+    # ):
+    #     super().calculate(atoms, properties, system_changes)
 
         # # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
         # if isinstance(self.atoms, MSONAtoms):
diff --git a/mlip_arena/tasks/__init__.py b/mlip_arena/tasks/__init__.py
index 733c309..0eea6b2 100644
--- a/mlip_arena/tasks/__init__.py
+++ b/mlip_arena/tasks/__init__.py
@@ -3,8 +3,8 @@
 import yaml
 from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
 
-from mlip_arena.models import MLIP
-from mlip_arena.models import REGISTRY as MODEL_REGISTRY
+# from mlip_arena.models import MLIP
+# from mlip_arena.models import REGISTRY as MODEL_REGISTRY
 
 try:
     from .elasticity import run as ELASTICITY
@@ -13,8 +13,9 @@
     from .neb import run as NEB
     from .neb import run_from_endpoints as NEB_FROM_ENDPOINTS
     from .optimize import run as OPT
+    from .phonon import run as PHONON
 
-    __all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY"]
+    __all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY", "PHONON"]
 except ImportError:
     pass
 
@@ -22,43 +23,43 @@
     REGISTRY = yaml.safe_load(f)
 
 
-class Task:
-    def __init__(self):
-        self.name: str = self.__class__.__name__  # display name on the leaderboard
+# class Task:
+#     def __init__(self):
+#         self.name: str = self.__class__.__name__  # display name on the leaderboard
 
-    def run_local(self, model: MLIP):
-        """Run the task using the given model and return the results."""
-        raise NotImplementedError
+#     def run_local(self, model: MLIP):
+#         """Run the task using the given model and return the results."""
+#         raise NotImplementedError
 
-    def run_hf(self, model: MLIP):
-        """Run the task using the given model and return the results."""
-        raise NotImplementedError
+#     def run_hf(self, model: MLIP):
+#         """Run the task using the given model and return the results."""
+#         raise NotImplementedError
 
-        # Calcualte evaluation metrics and postprocessed data
-        api = HfApi()
-        api.upload_file(
-            path_or_fileobj="results.json",
-            path_in_repo=f"{self.__class__.__name__}/{model.__class__.__name__}/results.json",  # Upload to a specific folder
-            repo_id="atomind/mlip-arena",
-            repo_type="dataset",
-        )
+#         # Calcualte evaluation metrics and postprocessed data
+#         api = HfApi()
+#         api.upload_file(
+#             path_or_fileobj="results.json",
+#             path_in_repo=f"{self.__class__.__name__}/{model.__class__.__name__}/results.json",  # Upload to a specific folder
+#             repo_id="atomind/mlip-arena",
+#             repo_type="dataset",
+#         )
 
-    def run_nersc(self, model: MLIP):
-        """Run the task using the given model and return the results."""
-        raise NotImplementedError
+#     def run_nersc(self, model: MLIP):
+#         """Run the task using the given model and return the results."""
+#         raise NotImplementedError
 
-    def get_results(self):
-        """Get the results from the task."""
-        # fs = HfFileSystem()
-        # files = fs.glob(f"datasets/atomind/mlip-arena/{self.__class__.__name__}/*/*.json")
+#     def get_results(self):
+#         """Get the results from the task."""
+#         # fs = HfFileSystem()
+#         # files = fs.glob(f"datasets/atomind/mlip-arena/{self.__class__.__name__}/*/*.json")
 
-        for model, metadata in MODEL_REGISTRY.items():
-            results = hf_hub_download(
-                repo_id="atomind/mlip-arena",
-                filename="results.json",
-                subfolder=f"{self.__class__.__name__}/{model}",
-                repo_type="dataset",
-                revision=None,
-            )
+#         for model, metadata in MODEL_REGISTRY.items():
+#             results = hf_hub_download(
+#                 repo_id="atomind/mlip-arena",
+#                 filename="results.json",
+#                 subfolder=f"{self.__class__.__name__}/{model}",
+#                 repo_type="dataset",
+#                 revision=None,
+#             )
 
-        return results
+#         return results
diff --git a/mlip_arena/tasks/phonon.py b/mlip_arena/tasks/phonon.py
new file mode 100644
index 0000000..57ec3cf
--- /dev/null
+++ b/mlip_arena/tasks/phonon.py
@@ -0,0 +1,181 @@
+"""
+This module has been adapted from Quacc (https://github.com/Quantum-Accelerators/quacc). By using this software, you agree to the Quacc license agreement: https://github.com/Quantum-Accelerators/quacc/blob/main/LICENSE.md
+
+
+BSD 3-Clause License
+
+Copyright (c) 2025, Andrew S. Rosen.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+- Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+- Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+- Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from pathlib import Path
+
+import numpy as np
+from phonopy import Phonopy
+from phonopy.structure.atoms import PhonopyAtoms
+from prefect import task
+from prefect.cache_policies import INPUTS, TASK_SOURCE
+from prefect.runtime import task_run
+
+from ase import Atoms
+from ase.calculators.calculator import BaseCalculator
+
+
+@task(cache_policy=TASK_SOURCE + INPUTS)
+def get_phonopy(
+    atoms: Atoms,
+    supercell_matrix: list[int] | None = None,
+    min_lengths: float | tuple[float, float, float] | None = None,
+    symprec: float = 1e-5,
+    distance: float = 0.01,
+    phonopy_kwargs: dict | None = None,
+) -> Phonopy:
+    """Build a ``Phonopy`` object for ``atoms`` and generate displacements.
+
+    When ``supercell_matrix`` is None and ``min_lengths`` is given, a diagonal
+    supercell matrix ``ceil(min_lengths / cell lengths)`` is used instead.
+    ``phonopy_kwargs`` (optional) is forwarded to the ``Phonopy`` constructor.
+    """
+    if supercell_matrix is None and min_lengths is not None:
+        supercell_matrix = np.diag(
+            np.ceil(np.asarray(min_lengths) / atoms.cell.lengths()).astype(int)
+        )
+
+    phonon = Phonopy(
+        PhonopyAtoms(
+            symbols=atoms.get_chemical_symbols(),
+            cell=atoms.get_cell(),
+            scaled_positions=atoms.get_scaled_positions(wrap=True),
+            masses=atoms.get_masses(),
+        ),
+        symprec=symprec,
+        supercell_matrix=supercell_matrix,
+        **(phonopy_kwargs or {}),
+    )
+    phonon.generate_displacements(distance=distance)
+
+    return phonon
+
+
+def _get_forces(
+    phononpy_atoms: PhonopyAtoms,
+    calculator: BaseCalculator,
+) -> np.ndarray:
+    """Convert ``phononpy_atoms`` to periodic ASE ``Atoms`` and return its forces."""
+    atoms = Atoms(
+        symbols=phononpy_atoms.symbols,
+        cell=phononpy_atoms.cell,
+        scaled_positions=phononpy_atoms.scaled_positions,
+        pbc=True,
+    )
+
+    atoms.calc = calculator
+
+    return atoms.get_forces()
+
+
+def _generate_task_run_name():
+    """Name the run after the task, chemical formula and calculator class."""
+    task_name = task_run.task_name
+    parameters = task_run.parameters
+
+    atoms = parameters["atoms"]
+    calculator = parameters["calculator"]
+
+    return (
+        f"{task_name}: {atoms.get_chemical_formula()} - {calculator.__class__.__name__}"
+    )
+
+
+@task(
+    name="PHONON",
+    task_run_name=_generate_task_run_name,
+    cache_policy=TASK_SOURCE + INPUTS,
+)
+def run(
+    atoms: Atoms,
+    calculator: BaseCalculator,
+    supercell_matrix: list[int] | None = None,
+    min_lengths: float | tuple[float, float, float] | None = None,
+    symprec: float = 1e-5,
+    distance: float = 0.01,
+    phonopy_kwargs: dict | None = None,
+    symmetry: bool = False,
+    t_min: float = 0.0,
+    t_max: float = 1000.0,
+    t_step: float = 10.0,
+    outdir: str | None = None,
+):
+    """Compute phonon properties of ``atoms`` from displaced-supercell forces.
+
+    Forces on every displaced supercell are evaluated with ``calculator`` and
+    turned into force constants (optionally symmetrized when ``symmetry`` is
+    true). The mesh, total DOS, thermal properties (``t_min``, ``t_max``,
+    ``t_step``) and an automatic band structure are then computed. When
+    ``outdir`` is given, ``band.yaml`` and ``phonopy.yaml`` are written there.
+
+    Returns:
+        dict: ``{"phonon": phonon}`` holding the resulting ``Phonopy`` object.
+    """
+    phonon = get_phonopy(
+        atoms=atoms,
+        supercell_matrix=supercell_matrix,
+        min_lengths=min_lengths,
+        symprec=symprec,
+        distance=distance,
+        phonopy_kwargs=phonopy_kwargs,
+    )
+
+    supercells_with_displacements = phonon.supercells_with_displacements
+
+    phonon.forces = [
+        _get_forces(supercell, calculator)
+        for supercell in supercells_with_displacements
+        if supercell is not None
+    ]
+    phonon.produce_force_constants()
+
+    if symmetry:
+        phonon.symmetrize_force_constants()
+        phonon.symmetrize_force_constants_by_space_group()
+
+    phonon.run_mesh(with_eigenvectors=True)
+    phonon.run_total_dos()
+    phonon.run_thermal_properties(t_step=t_step, t_max=t_max, t_min=t_min)  # type: ignore
+    phonon.auto_band_structure(
+        write_yaml=outdir is not None,
+        filename=Path(outdir, "band.yaml") if outdir is not None else "band.yaml",
+    )
+    if outdir is not None:
+        phonon.save(
+            Path(outdir, "phonopy.yaml"), settings={"force_constants": True}
+        )
+
+    return {
+        "phonon": phonon,
+    }
diff --git a/mlip_arena/tasks/utils.py b/mlip_arena/tasks/utils.py
index da920f9..e2e4bf1 100644
--- a/mlip_arena/tasks/utils.py
+++ b/mlip_arena/tasks/utils.py
@@ -2,13 +2,15 @@
 
 from __future__ import annotations
 
+from pprint import pformat
+
+import torch
 from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
 
 from ase import units
-from ase.calculators.calculator import Calculator, BaseCalculator
+from ase.calculators.calculator import BaseCalculator
 from ase.calculators.mixing import SumCalculator
 from mlip_arena.models import MLIPEnum
-from mlip_arena.models.utils import get_freer_device
 
 try:
     from prefect.logging import get_run_logger
@@ -17,16 +19,48 @@
 except (ImportError, RuntimeError):
     from loguru import logger
 
-from pprint import pformat
+
+def get_freer_device() -> torch.device:
+    """Get the CUDA GPU with the most free memory, else MPS, else CPU.
+
+    Returns:
+        torch.device: The selected CUDA device, or MPS/CPU as a fallback.
+
+    Note:
+        Free memory is estimated as total memory minus `memory_allocated`.
+    """
+    device_count = torch.cuda.device_count()
+    if device_count > 0:
+        # If CUDA GPUs are available, select the one with the most free memory
+        mem_free = [
+            torch.cuda.get_device_properties(i).total_memory
+            - torch.cuda.memory_allocated(i)
+            for i in range(device_count)
+        ]
+        free_gpu_index = mem_free.index(max(mem_free))
+        device = torch.device(f"cuda:{free_gpu_index}")
+        logger.info(
+            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
+        )
+    elif torch.backends.mps.is_available():
+        # If no CUDA GPUs are available but MPS is, use MPS
+        logger.info("No GPU available. Using MPS.")
+        device = torch.device("mps")
+    else:
+        # Fallback to CPU if neither CUDA GPUs nor MPS are available
+        logger.info("No GPU or MPS available. Using CPU.")
+        device = torch.device("cpu")
+
+    return device
 
 
 def get_calculator(
-    calculator_name: str | MLIPEnum | Calculator | SumCalculator,
-    calculator_kwargs: dict | None,
+    calculator_name: str | MLIPEnum | BaseCalculator,
+    calculator_kwargs: dict | None = None,
     dispersion: bool = False,
     dispersion_kwargs: dict | None = None,
     device: str | None = None,
-) -> Calculator | SumCalculator:
+) -> BaseCalculator:
     """Get a calculator with optional dispersion correction."""
 
     device = device or str(get_freer_device())
@@ -40,11 +74,15 @@ def get_calculator(
         calc = calculator_name.value(**calculator_kwargs)
     elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
         calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
-    elif isinstance(calculator_name, type) and issubclass(calculator_name, BaseCalculator):
+    elif isinstance(calculator_name, type) and issubclass(
+        calculator_name, BaseCalculator
+    ):
         logger.warning(f"Using custom calculator class: {calculator_name}")
         calc = calculator_name(**calculator_kwargs)
-    elif isinstance(calculator_name, Calculator | SumCalculator):
-        logger.warning(f"Using custom calculator object (kwargs are ignored): {calculator_name}")
+    elif isinstance(calculator_name, BaseCalculator):
+        logger.warning(
+            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
+        )
         calc = calculator_name
     else:
         raise ValueError(f"Invalid calculator: {calculator_name}")
@@ -69,5 +107,5 @@ def get_calculator(
         if dispersion_kwargs:
             logger.info(pformat(dispersion_kwargs))
 
-    assert isinstance(calc, Calculator | SumCalculator)
+    assert isinstance(calc, BaseCalculator)
     return calc