diff --git a/mlip_arena/data/collate.py b/mlip_arena/data/collate.py new file mode 100644 index 0000000..484cbac --- /dev/null +++ b/mlip_arena/data/collate.py @@ -0,0 +1,201 @@ +import numpy as np +import torch + +# TODO: consider using vesin +from matscipy.neighbours import neighbour_list +from torch_geometric.data import Data + +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator + + +def get_neighbor( + atoms: Atoms, cutoff: float, self_interaction: bool = False +): + pbc = atoms.pbc + cell = atoms.cell.array + + i, j, S = neighbour_list( + quantities="ijS", + pbc=pbc, + cell=cell, + positions=atoms.positions, + cutoff=cutoff + ) + + if not self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = i == j + true_self_edge &= np.all(S == 0, axis=1) + keep_edge = ~true_self_edge + + i = i[keep_edge] + j = j[keep_edge] + S = S[keep_edge] + + edge_index = np.stack((i, j)).astype(np.int64) + edge_shift = np.dot(S, cell) + + return edge_index, edge_shift + + + +def collate_fn(batch: list[Atoms], cutoff: float) -> Data: + """Collate a list of Atoms objects into a single batched Atoms object.""" + + # Offset the edge indices for each graph to ensure they remain disconnected + offset = 0 + + node_batch = [] + + numbers_batch = [] + positions_batch = [] + # ec_batch = [] + + forces_batch = [] + charges_batch = [] + magmoms_batch = [] + dipoles_batch = [] + + edge_index_batch = [] + edge_shift_batch = [] + + cell_batch = [] + natoms_batch = [] + + energy_batch = [] + stress_batch = [] + + for i, atoms in enumerate(batch): + + edge_index, edge_shift = get_neighbor(atoms, cutoff=cutoff, self_interaction=False) + + edge_index[0] += offset + edge_index[1] += offset + edge_index_batch.append(torch.tensor(edge_index)) + edge_shift_batch.append(torch.tensor(edge_shift)) + + natoms = len(atoms) + offset += natoms + node_batch.append(torch.ones(natoms, dtype=torch.long) * i) + natoms_batch.append(natoms) + + cell_batch.append(torch.tensor(atoms.cell.array)) + numbers_batch.append(torch.tensor(atoms.numbers)) + positions_batch.append(torch.tensor(atoms.positions)) + + # ec_batch.append([Atom(int(a)).elecronic_encoding for a in atoms.numbers]) + + charges_batch.append( + atoms.get_initial_charges() + if atoms.get_initial_charges().any() + else torch.full((natoms,), torch.nan) + ) + magmoms_batch.append( + atoms.get_initial_magnetic_moments() + if atoms.get_initial_magnetic_moments().any() + else torch.full((natoms,), torch.nan) + ) + + # Create the new 'arrays' data for the batch + + cell_batch = torch.stack(cell_batch, dim=0) + node_batch = torch.cat(node_batch, dim=0) + positions_batch = torch.cat(positions_batch, dim=0) + numbers_batch = torch.cat(numbers_batch, dim=0) + natoms_batch = torch.tensor(natoms_batch, dtype=torch.long) + + charges_batch = torch.cat(charges_batch, dim=0) if charges_batch else None + magmoms_batch = torch.cat(magmoms_batch, dim=0) if magmoms_batch else None + + # ec_batch = list(map(lambda a: Atom(int(a)).elecronic_encoding, numbers_batch)) + # ec_batch = torch.stack(ec_batch, dim=0) + + edge_index_batch = torch.cat(edge_index_batch, dim=1) + edge_shift_batch = torch.cat(edge_shift_batch, dim=0) + + arrays_batch_concatenated = { + "cell": cell_batch, + "positions": positions_batch, + "edge_index": edge_index_batch, + "edge_shift": edge_shift_batch, + "numbers": numbers_batch, + "num_nodes": offset, + "batch": node_batch, + "charges": charges_batch, + "magmoms": magmoms_batch, + # "ec": ec_batch, + "natoms": natoms_batch, + "cutoff": torch.tensor(cutoff), + } + + # TODO: custom fields + + # Create a new Data object with the concatenated arrays data + batch_data = Data.from_dict(arrays_batch_concatenated) + + return batch_data + + +def decollate_fn(batch_data: Data) -> list[Atoms]: + """Decollate a batched Data object into a list of individual Atoms objects.""" + + # FIXME: this function is not working properly when the batch_data is on GPU. + # TODO: create a new Cell class using torch tensor to handle device placement. + # As a temporary fix, detach the batch_data from the GPU and move it to CPU. + batch_data = batch_data.detach().cpu() + + # Initialize empty lists to store individual data entries + individual_entries = [] + + # Split the 'batch' attribute to identify data entries + unique_batches = batch_data.batch.unique(sorted=True) + + for i in unique_batches: + # Identify the indices corresponding to the current data entry + entry_indices = (batch_data.batch == i).nonzero(as_tuple=True)[0] + + # Extract the attributes for the current data entry + cell = batch_data.cell[i] + numbers = batch_data.numbers[entry_indices] + positions = batch_data.positions[entry_indices] + # edge_index = batch_data.edge_index[:, entry_indices] + # edge_shift = batch_data.edge_shift[entry_indices] + # batch_data.ec[entry_indices] if batch_data.ec is not None else None + + # Optional fields + energy = batch_data.energy[i] if "energy" in batch_data else None + forces = batch_data.forces[entry_indices] if "forces" in batch_data else None + stress = batch_data.stress[i] if "stress" in batch_data else None + + # charges = batch_data.charges[entry_indices] if "charges" in batch_data else None + # magmoms = batch_data.magmoms[entry_indices] if "magmoms" in batch_data else None + # dipoles = batch_data.dipoles[entry_indices] if "dipoles" in batch_data else None + + # TODO: cumstom fields + + # Create an 'Atoms' object for the current data entry + atoms = Atoms( + cell=cell, + positions=positions, + numbers=numbers, + # forces=None if torch.any(torch.isnan(forces)) else forces, + # charges=None if torch.any(torch.isnan(charges)) else charges, + # magmoms=None if torch.any(torch.isnan(magmoms)) else magmoms, + # dipoles=None if torch.any(torch.isnan(dipoles)) else dipoles, + # energy=None if torch.isnan(energy) else energy, + # stress=None if torch.any(torch.isnan(stress)) else stress, + ) + + atoms.calc = SinglePointCalculator( + energy=energy, + forces=forces, + stress=stress, + # charges=charges, + # magmoms=magmoms, + ) # type: ignore + + # Append the individual data entry to the list + individual_entries.append(atoms) + + return individual_entries diff --git a/mlip_arena/models/__init__.py b/mlip_arena/models/__init__.py index d963e5f..5238fdd 100644 --- a/mlip_arena/models/__init__.py +++ b/mlip_arena/models/__init__.py @@ -6,11 +6,21 @@ import torch import yaml -from ase import Atoms -from ase.calculators.calculator import Calculator, all_changes from huggingface_hub import PyTorchModelHubMixin from torch import nn +from ase import Atoms +from ase.calculators.calculator import Calculator, all_changes +from mlip_arena.data.collate import collate_fn +from mlip_arena.models.utils import get_freer_device + +try: + from prefect.logging import get_run_logger + + logger = get_run_logger() +except (ImportError, RuntimeError): + from loguru import logger + # from torch_geometric.data import Data with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f: @@ -20,14 +30,17 @@ for model, metadata in REGISTRY.items(): try: - module = importlib.import_module(f"{__package__}.{metadata['module']}.{metadata['family']}") + module = importlib.import_module( + f"{__package__}.{metadata['module']}.{metadata['family']}" + ) MLIPMap[model] = getattr(module, metadata["class"]) except (ModuleNotFoundError, AttributeError, ValueError) as e: - print(e) + logger.warning(e) continue MLIPEnum = Enum("MLIPEnum", MLIPMap) + class MLIP( nn.Module, PyTorchModelHubMixin, @@ -35,6 +48,9 @@ class MLIP( ): def __init__(self, model: nn.Module) -> None: super().__init__() + # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534 + # torch._dynamo.config.compiled_autograd = True + # self.model = torch.compile(model) self.model = model def forward(self, x): @@ -47,7 +63,9 @@ class MLIPCalculator(MLIP, Calculator): def __init__( self, - model, + model: nn.Module, + device: torch.device | None = None, + cutoff: float = 6.0, # ASE Calculator restart=None, atoms=None, @@ -60,12 +78,24 @@ def __init__( ) # Initialize ASE Calculator part # Additional initialization if needed # self.name: str = self.__class__.__name__ + self.device = device or get_freer_device() + self.cutoff = cutoff + self.model.to(self.device) # self.device = device or torch.device( # "cuda" if torch.cuda.is_available() else "cpu" # ) # self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device) # self.implemented_properties = ["energy", "forces", "stress"] + # def __getstate__(self): + # state = self.__dict__.copy() + # state["_modules"]["model"] = state["_modules"]["model"]._orig_mod + # return state + + # def __setstate__(self, state): + # self.__dict__.update(state) + # self.model = torch.compile(state["_modules"]["model"]) + def calculate( self, atoms: Atoms, @@ -75,7 +105,11 @@ def calculate( """Calculate energies and forces for the given Atoms object""" super().calculate(atoms, properties, system_changes) - output = self.forward(atoms) + # TODO: move collate_fn to here in MLIPCalculator + data = collate_fn([atoms], cutoff=self.cutoff).to(self.device) + output = self.forward(data) + + # TODO: decollate_fn self.results = {} if "energy" in properties: @@ -85,13 +119,14 @@ def calculate( if "stress" in properties: self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy() - def forward(self, x: Atoms) -> dict[str, torch.Tensor]: - """Implement data conversion, graph creation, and model forward pass + # def forward(self, x: Atoms) -> dict[str, torch.Tensor]: + # """Implement data conversion, graph creation, and model forward pass + + # Example implementation: + # 1. Use `ase.neighborlist.NeighborList` to get neighbor list + # 2. Create `torch_geometric.data.Data` object and copy the data + # 3. Pass the `Data` object to the model and return the output - Example implementation: - 1. Use `ase.neighborlist.NeighborList` to get neighbor list - 2. Create `torch_geometric.data.Data` object and copy the data - 3. Pass the `Data` object to the model and return the output + # """ - """ - raise NotImplementedError + # raise NotImplementedError diff --git a/mlip_arena/models/classicals/__init__.py b/mlip_arena/models/classicals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlip_arena/models/classicals/zbl.py b/mlip_arena/models/classicals/zbl.py new file mode 100644 index 0000000..5821539 --- /dev/null +++ b/mlip_arena/models/classicals/zbl.py @@ -0,0 +1,214 @@ +import torch +import torch.linalg as LA +import torch.nn as nn +import torch_scatter +from torch_geometric.data import Data + +from ase.data import covalent_radii +from ase.units import _e, _eps0, m, pi +from e3nn.util.jit import compile_mode # TODO: e3nn allows autograd in compiled model + + +@compile_mode("script") +class ZBL(nn.Module): + """Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion""" + + def __init__( + self, + trianable: bool = False, + **kwargs, + ) -> None: + nn.Module.__init__(self, **kwargs) + + torch.set_default_dtype(torch.double) + + self.a = torch.nn.parameter.Parameter( + torch.tensor( + [0.18175, 0.50986, 0.28022, 0.02817], dtype=torch.get_default_dtype() + ), + requires_grad=trianable, + ) + self.b = torch.nn.parameter.Parameter( + torch.tensor( + [-3.19980, -0.94229, -0.40290, -0.20162], + dtype=torch.get_default_dtype(), + ), + requires_grad=trianable, + ) + + self.a0 = torch.nn.parameter.Parameter( + torch.tensor(0.46850, dtype=torch.get_default_dtype()), + requires_grad=trianable, + ) + + self.p = torch.nn.parameter.Parameter( + torch.tensor(0.23, dtype=torch.get_default_dtype()), requires_grad=trianable + ) + + self.register_buffer( + "covalent_radii", + torch.tensor( + covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + + def phi(self, x): + return torch.einsum("i,ij->j", self.a, torch.exp(torch.outer(self.b, x))) + + def d_phi(self, x): + return torch.einsum( + "i,ij->j", self.a * self.b, torch.exp(torch.outer(self.b, x)) + ) + + def dd_phi(self, x): + return torch.einsum( + "i,ij->j", self.a * self.b**2, torch.exp(torch.outer(self.b, x)) + ) + + def eij( + self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor + ) -> torch.Tensor: # [eV] + return _e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij) + + def d_eij( + self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor + ) -> torch.Tensor: # [eV / A] + return -_e * m / (4 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**2) + + def dd_eij( + self, zi: torch.Tensor, zj: torch.Tensor, rij: torch.Tensor + ) -> torch.Tensor: # [eV / A^2] + return _e * m / (2 * pi * _eps0) * torch.div(torch.mul(zi, zj), rij**3) + + def switch_fn( + self, + zi: torch.Tensor, + zj: torch.Tensor, + rij: torch.Tensor, + aij: torch.Tensor, + router: torch.Tensor, + rinner: torch.Tensor, + ) -> torch.Tensor: # [eV] + # aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p)) + + xrouter = router / aij + + energy = self.eij(zi, zj, router) * self.phi(xrouter) + + grad1 = self.d_eij(zi, zj, router) * self.phi(xrouter) + self.eij( + zi, zj, router + ) * self.d_phi(xrouter) + + grad2 = ( + self.dd_eij(zi, zj, router) * self.phi(xrouter) + + self.d_eij(zi, zj, router) * self.d_phi(xrouter) + + self.d_eij(zi, zj, router) * self.d_phi(xrouter) + + self.eij(zi, zj, router) * self.dd_phi(xrouter) + ) + + A = (-3 * grad1 + (router - rinner) * grad2) / (router - rinner) ** 2 + B = (2 * grad1 - (router - rinner) * grad2) / (router - rinner) ** 3 + C = ( + -energy + + 1.0 / 2.0 * (router - rinner) * grad1 + - 1.0 / 12.0 * (router - rinner) ** 2 * grad2 + ) + + switching = torch.where( + rij < rinner, + C, + A / 3.0 * (rij - rinner) ** 3 + B / 4.0 * (rij - rinner) ** 4 + C, + ) + + return switching + + def envelope(self, r: torch.Tensor, rc: torch.Tensor, p: int = 6): + x = r / rc + y = ( + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p) + + p * (p + 2.0) * torch.pow(x, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(x, p + 2) + ) * (x < 1) + return y + + def _get_derivatives(self, energy: torch.Tensor, data: Data): + egradi, egradij = torch.autograd.grad( + outputs=[energy], # TODO: generalized derivatives + inputs=[data.positions, data.vij], # TODO: generalized derivatives + grad_outputs=[torch.ones_like(energy)], + retain_graph=True, + create_graph=True, + allow_unused=True, + ) + + volume = torch.det(data.cell) # (batch,) + rfaxy = torch.einsum("ax,ay->axy", data.vij, -egradij) + + edge_batch = data.batch[data.edge_index[0]] + + stress = ( + -0.5 + * torch_scatter.scatter_sum(rfaxy, edge_batch, dim=0) + / volume.view(-1, 1) + ) + + return -egradi, stress + + def forward( + self, + data: Data, + ) -> dict[str, torch.Tensor]: + # TODO: generalized derivatives + data.positions.requires_grad_(True) + + numbers = data.numbers # (sum(N), ) + positions = data.positions # (sum(N), 3) + edge_index = data.edge_index # (2, sum(E)) + edge_shift = data.edge_shift # (sum(E), 3) + batch = data.batch # (sum(N), ) + + edge_src, edge_dst = edge_index[0], edge_index[1] + + if "rij" not in data or "vij" not in data: + data.vij = positions[edge_dst] - positions[edge_src] + edge_shift + data.rij = LA.norm(data.vij, dim=-1) + + rbond = ( + self.covalent_radii[numbers[edge_src]] + + self.covalent_radii[numbers[edge_dst]] + ) + + rij = data.rij + zi = numbers[edge_src] # (sum(E), ) + zj = numbers[edge_dst] # (sum(E), ) + + aij = self.a0 / (torch.pow(zi, self.p) + torch.pow(zj, self.p)) # (sum(E), ) + + energy_pairs = ( + self.eij(zi, zj, rij) + * self.phi(rij / aij.to(rij)) + * self.envelope(rij, torch.min(data.cutoff, rbond)) + ) + + energy_nodes = 0.5 * torch_scatter.scatter_add( + src=energy_pairs, + index=edge_dst, + dim=0, + ) # (sum(N), ) + + energies = torch_scatter.scatter_add( + src=energy_nodes, + index=batch, + dim=0, + ) # (B, ) + + # TODO: generalized derivatives + forces, stress = self._get_derivatives(energies, data) + + return { + "energy": energies, + "forces": forces, + "stress": stress, + } diff --git a/mlip_arena/models/registry.yaml b/mlip_arena/models/registry.yaml index b422e11..dfb3e62 100644 --- a/mlip_arena/models/registry.yaml +++ b/mlip_arena/models/registry.yaml @@ -84,6 +84,7 @@ MatterSim: - eos_alloy gpu-tasks: - homonuclear-diatomics + - stability github: https://github.com/microsoft/mattersim doi: https://arxiv.org/abs/2405.04967 date: 2024-12-05 @@ -264,6 +265,7 @@ ALIGNN: - MP22 gpu-tasks: - homonuclear-diatomics + - stability # - combustion prediction: EFS nvt: true @@ -309,6 +311,7 @@ ORBv2: gpu-tasks: - homonuclear-diatomics - combustion + - stability github: https://github.com/orbital-materials/orb-models doi: date: 2024-10-15 diff --git a/mlip_arena/tasks/optimize.py b/mlip_arena/tasks/optimize.py index 49db78c..8c175a2 100644 --- a/mlip_arena/tasks/optimize.py +++ b/mlip_arena/tasks/optimize.py @@ -111,6 +111,9 @@ def run( logger.info(f"Criterion: {pformat(criterion)}") optimizer_instance.run(**criterion) + return { "atoms": atoms, + "steps": optimizer_instance.nsteps, + "converged": optimizer_instance.converged(), } diff --git a/tests/test_internal_calculators.py b/tests/test_internal_calculators.py new file mode 100644 index 0000000..d094897 --- /dev/null +++ b/tests/test_internal_calculators.py @@ -0,0 +1,36 @@ +import numpy as np +from mlip_arena.models import MLIPCalculator +from mlip_arena.models.classicals.zbl import ZBL + +from ase.build import bulk + + +def test_zbl(): + calc = MLIPCalculator(model=ZBL(), cutoff=6.0) + + energies = [] + forces = [] + stresses = [] + + lattice_constants = [1, 3, 5, 7] + + for a in lattice_constants: + atoms = bulk("Cu", "fcc", a=a) * (2, 2, 2) + atoms.calc = calc + + energies.append(atoms.get_potential_energy()) + forces.append(atoms.get_forces()) + stresses.append(atoms.get_stress(voigt=False)) + + # test energy monotonicity + assert all(np.diff(energies) <= 0), "Energy is not monotonically decreasing with increasing lattice constant" + + # test force vectors are all zeros due to symmetry + for f in forces: + assert np.allclose(f, 0), "Forces should be zero due to symmetry" + + # test trace of stress is monotonically increasing (less negative) and zero beyond cutoff + traces = [np.trace(s) for s in stresses] + + assert all(np.diff(traces) >= 0), "Trace of stress is not monotonically increasing with increasing lattice constant" + assert np.allclose(stresses[-1], 0), "Stress should be zero beyond cutoff"