Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convenient ZBL torch calculator #44

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions mlip_arena/data/collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import numpy as np
import torch

# TODO: consider using vesin
from matscipy.neighbours import neighbour_list
from torch_geometric.data import Data

from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator


def get_neighbor(
atoms: Atoms, cutoff: float, self_interaction: bool = False
):
pbc = atoms.pbc
cell = atoms.cell.array

i, j, S = neighbour_list(
quantities="ijS",
pbc=pbc,
cell=cell,
positions=atoms.positions,
cutoff=cutoff
)

if not self_interaction:
# Eliminate self-edges that don't cross periodic boundaries
true_self_edge = i == j
true_self_edge &= np.all(S == 0, axis=1)
keep_edge = ~true_self_edge

i = i[keep_edge]
j = j[keep_edge]
S = S[keep_edge]

edge_index = np.stack((i, j)).astype(np.int64)
edge_shift = np.dot(S, cell)

return edge_index, edge_shift



def collate_fn(batch: list[Atoms], cutoff: float) -> Data:
"""Collate a list of Atoms objects into a single batched Atoms object."""

# Offset the edge indices for each graph to ensure they remain disconnected
offset = 0

node_batch = []

numbers_batch = []
positions_batch = []
# ec_batch = []

forces_batch = []
charges_batch = []
magmoms_batch = []
dipoles_batch = []

edge_index_batch = []
edge_shift_batch = []

cell_batch = []
natoms_batch = []

energy_batch = []
stress_batch = []

for i, atoms in enumerate(batch):

edge_index, edge_shift = get_neighbor(atoms, cutoff=cutoff, self_interaction=False)

edge_index[0] += offset
edge_index[1] += offset
edge_index_batch.append(torch.tensor(edge_index))
edge_shift_batch.append(torch.tensor(edge_shift))

natoms = len(atoms)
offset += natoms
node_batch.append(torch.ones(natoms, dtype=torch.long) * i)
natoms_batch.append(natoms)

cell_batch.append(torch.tensor(atoms.cell.array))
numbers_batch.append(torch.tensor(atoms.numbers))
positions_batch.append(torch.tensor(atoms.positions))

# ec_batch.append([Atom(int(a)).elecronic_encoding for a in atoms.numbers])

charges_batch.append(
atoms.get_initial_charges()
if atoms.get_initial_charges().any()
else torch.full((natoms,), torch.nan)
)
magmoms_batch.append(
atoms.get_initial_magnetic_moments()
if atoms.get_initial_magnetic_moments().any()
else torch.full((natoms,), torch.nan)
)

# Create the new 'arrays' data for the batch

cell_batch = torch.stack(cell_batch, dim=0)
node_batch = torch.cat(node_batch, dim=0)
positions_batch = torch.cat(positions_batch, dim=0)
numbers_batch = torch.cat(numbers_batch, dim=0)
natoms_batch = torch.tensor(natoms_batch, dtype=torch.long)

charges_batch = torch.cat(charges_batch, dim=0) if charges_batch else None
magmoms_batch = torch.cat(magmoms_batch, dim=0) if magmoms_batch else None

# ec_batch = list(map(lambda a: Atom(int(a)).elecronic_encoding, numbers_batch))
# ec_batch = torch.stack(ec_batch, dim=0)

edge_index_batch = torch.cat(edge_index_batch, dim=1)
edge_shift_batch = torch.cat(edge_shift_batch, dim=0)

arrays_batch_concatenated = {
"cell": cell_batch,
"positions": positions_batch,
"edge_index": edge_index_batch,
"edge_shift": edge_shift_batch,
"numbers": numbers_batch,
"num_nodes": offset,
"batch": node_batch,
"charges": charges_batch,
"magmoms": magmoms_batch,
# "ec": ec_batch,
"natoms": natoms_batch,
"cutoff": torch.tensor(cutoff),
}

# TODO: custom fields

# Create a new Data object with the concatenated arrays data
batch_data = Data.from_dict(arrays_batch_concatenated)

return batch_data


def decollate_fn(batch_data: Data) -> list[Atoms]:
"""Decollate a batched Data object into a list of individual Atoms objects."""

# FIXME: this function is not working properly when the batch_data is on GPU.
# TODO: create a new Cell class using torch tensor to handle device placement.
# As a temporary fix, detach the batch_data from the GPU and move it to CPU.
batch_data = batch_data.detach().cpu()

# Initialize empty lists to store individual data entries
individual_entries = []

# Split the 'batch' attribute to identify data entries
unique_batches = batch_data.batch.unique(sorted=True)

for i in unique_batches:
# Identify the indices corresponding to the current data entry
entry_indices = (batch_data.batch == i).nonzero(as_tuple=True)[0]

# Extract the attributes for the current data entry
cell = batch_data.cell[i]
numbers = batch_data.numbers[entry_indices]
positions = batch_data.positions[entry_indices]
# edge_index = batch_data.edge_index[:, entry_indices]
# edge_shift = batch_data.edge_shift[entry_indices]
# batch_data.ec[entry_indices] if batch_data.ec is not None else None

# Optional fields
energy = batch_data.energy[i] if "energy" in batch_data else None
forces = batch_data.forces[entry_indices] if "forces" in batch_data else None
stress = batch_data.stress[i] if "stress" in batch_data else None

# charges = batch_data.charges[entry_indices] if "charges" in batch_data else None
# magmoms = batch_data.magmoms[entry_indices] if "magmoms" in batch_data else None
# dipoles = batch_data.dipoles[entry_indices] if "dipoles" in batch_data else None

# TODO: cumstom fields

# Create an 'Atoms' object for the current data entry
atoms = Atoms(
cell=cell,
positions=positions,
numbers=numbers,
# forces=None if torch.any(torch.isnan(forces)) else forces,
# charges=None if torch.any(torch.isnan(charges)) else charges,
# magmoms=None if torch.any(torch.isnan(magmoms)) else magmoms,
# dipoles=None if torch.any(torch.isnan(dipoles)) else dipoles,
# energy=None if torch.isnan(energy) else energy,
# stress=None if torch.any(torch.isnan(stress)) else stress,
)

atoms.calc = SinglePointCalculator(
energy=energy,
forces=forces,
stress=stress,
# charges=charges,
# magmoms=magmoms,
) # type: ignore

# Append the individual data entry to the list
individual_entries.append(atoms)

return individual_entries
63 changes: 49 additions & 14 deletions mlip_arena/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,21 @@

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from mlip_arena.data.collate import collate_fn
from mlip_arena.models.utils import get_freer_device

try:
from prefect.logging import get_run_logger

logger = get_run_logger()
except (ImportError, RuntimeError):
from loguru import logger

# from torch_geometric.data import Data

with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
Expand All @@ -20,21 +30,27 @@

for model, metadata in REGISTRY.items():
try:
module = importlib.import_module(f"{__package__}.{metadata['module']}.{metadata['family']}")
module = importlib.import_module(
f"{__package__}.{metadata['module']}.{metadata['family']}"
)
MLIPMap[model] = getattr(module, metadata["class"])
except (ModuleNotFoundError, AttributeError, ValueError) as e:
print(e)
logger.warning(e)
continue

MLIPEnum = Enum("MLIPEnum", MLIPMap)


class MLIP(
nn.Module,
PyTorchModelHubMixin,
tags=["atomistic-simulation", "MLIP"],
):
def __init__(self, model: nn.Module) -> None:
super().__init__()
# https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
# torch._dynamo.config.compiled_autograd = True
# self.model = torch.compile(model)
self.model = model

def forward(self, x):
Expand All @@ -47,7 +63,9 @@ class MLIPCalculator(MLIP, Calculator):

def __init__(
self,
model,
model: nn.Module,
device: torch.device | None = None,
cutoff: float = 6.0,
# ASE Calculator
restart=None,
atoms=None,
Expand All @@ -60,12 +78,24 @@ def __init__(
) # Initialize ASE Calculator part
# Additional initialization if needed
# self.name: str = self.__class__.__name__
self.device = device or get_freer_device()
self.cutoff = cutoff
self.model.to(self.device)
# self.device = device or torch.device(
# "cuda" if torch.cuda.is_available() else "cpu"
# )
# self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
# self.implemented_properties = ["energy", "forces", "stress"]

# def __getstate__(self):
# state = self.__dict__.copy()
# state["_modules"]["model"] = state["_modules"]["model"]._orig_mod
# return state

# def __setstate__(self, state):
# self.__dict__.update(state)
# self.model = torch.compile(state["_modules"]["model"])

def calculate(
self,
atoms: Atoms,
Expand All @@ -75,7 +105,11 @@ def calculate(
"""Calculate energies and forces for the given Atoms object"""
super().calculate(atoms, properties, system_changes)

output = self.forward(atoms)
# TODO: move collate_fn to here in MLIPCalculator
data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
output = self.forward(data)

# TODO: decollate_fn

self.results = {}
if "energy" in properties:
Expand All @@ -85,13 +119,14 @@ def calculate(
if "stress" in properties:
self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()

def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
"""Implement data conversion, graph creation, and model forward pass
# def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
# """Implement data conversion, graph creation, and model forward pass

# Example implementation:
# 1. Use `ase.neighborlist.NeighborList` to get neighbor list
# 2. Create `torch_geometric.data.Data` object and copy the data
# 3. Pass the `Data` object to the model and return the output

Example implementation:
1. Use `ase.neighborlist.NeighborList` to get neighbor list
2. Create `torch_geometric.data.Data` object and copy the data
3. Pass the `Data` object to the model and return the output
# """

"""
raise NotImplementedError
# raise NotImplementedError
Empty file.
Loading
Loading