Skip to content

Commit

Permalink
Merge pull request #257 from torchmd/comp6v2
Browse files Browse the repository at this point in the history
added comp6v2 dataset
  • Loading branch information
stefdoerr authored Jan 24, 2024
2 parents 07424e1 + f01ee90 commit 8e113eb
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
12 changes: 11 additions & 1 deletion torchmdnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@

from .ace import Ace
from .ani import ANI1, ANI1CCX, ANI1X, ANI2X
from .comp6 import ANIMD, DrugBank, GDB07to09, GDB10to13, Tripeptides, S66X8, COMP6v1
from .comp6 import (
ANIMD,
DrugBank,
GDB07to09,
GDB10to13,
Tripeptides,
S66X8,
COMP6v1,
COMP6v2,
)
from .custom import Custom
from .water import WaterBox
from .hdf import HDF5
Expand All @@ -22,6 +31,7 @@
"ANI1X",
"ANI2X",
"COMP6v1",
"COMP6v2",
"Custom",
"DrugBank",
"GDB07to09",
Expand Down
96 changes: 95 additions & 1 deletion torchmdnet/datasets/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import h5py
import numpy as np
import torch as pt
from torch_geometric.data import Data, Dataset, download_url
from torch_geometric.data import Data, Dataset, download_url, extract_tar
from torchmdnet.datasets.memdataset import MemmappedDataset
from tqdm import tqdm
from torchmdnet.datasets.ani import ANIBase
import os

"""
COmprehensive Machine-learning Potential (COMP6) Benchmark Suite
Expand Down Expand Up @@ -313,3 +315,95 @@ def len(self):
def get(self, idx):
    """Return the sample stored at flat index ``idx``.

    ``self.subset_indices[idx]`` provides a ``(subset, sample)`` index pair,
    which is then used to look the sample up in ``self.subsets``.
    """
    subset_pos, sample_pos = self.subset_indices[idx]
    owning_subset = self.subsets[subset_pos]
    return owning_subset[sample_pos]


class COMP6v2(ANIBase):
    """Dataset for the COmprehensive Machine-learning Potential (COMP6) Benchmark Suite version 2.0.

    COMP6v2 is a data set of density functional properties for molecules
    containing H, C, N, O, S, F, and Cl. It is available at different levels
    of theory, but here we use wB97X/631Gd, which was used in evaluating
    ANI-2x.

    References:

    - https://pubs.acs.org/doi/10.1021/acs.jctc.0c00121
    """

    # Per-element reference energies keyed by atomic number; they are
    # multiplied by HARTREE_TO_EV when exposed through get_atomref().
    # Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat
    _ELEMENT_ENERGIES = {
        1: -0.5978583943827134,  # H
        6: -38.08933878049795,  # C
        7: -54.711968298621066,  # N
        8: -75.19106774742086,  # O
        9: -99.80348506781634,  # F
        16: -398.1577125334925,  # S
        17: -460.1681939421027,  # Cl
    }

    @property
    def raw_url(self):
        """URL of the tar.gz archive that contains the raw HDF5 file."""
        return "https://zenodo.org/records/10126157/files/COMP6v2_wB97X-631Gd.tar.gz"

    @property
    def raw_file_names(self):
        """Raw file path (relative to the raw directory) expected after extraction."""
        return [os.path.join("comp6v2_final_h5", "COMP6v2_wB97X-631Gd.h5")]

    def download(self):
        """Download the archive into ``raw_dir``, extract it there and delete the archive."""
        archive = download_url(self.raw_url, self.raw_dir)
        extract_tar(archive, self.raw_dir)
        # The extracted HDF5 file is all that is needed afterwards.
        os.remove(archive)

    def sample_iter(self, mol_ids=False):
        """Yield one ``Data`` sample per conformation found in the raw HDF5 file.

        The file is organised in top-level groups. Example listing of the
        datasets inside one group::

            [('cm5_atomic_charges', <HDF5 dataset "cm5_atomic_charges": shape (128, 313), type "<f4">),
             ('coordinates', <HDF5 dataset "coordinates": shape (128, 312, 3), type "<f4">),
             ('energies', <HDF5 dataset "energies": shape (128,), type "<f8">),
             ('forces', <HDF5 dataset "forces": shape (128, 312, 3), type "<f4">),
             ('hirshfeld_atomic_charges', <HDF5 dataset "hirshfeld_atomic_charges": shape (128, 313), type "<f4">),
             ('hirshfeld_atomic_dipoles', <HDF5 dataset "hirshfeld_atomic_dipoles": shape (128, 313, 3), type "<f4">),
             ('species', <HDF5 dataset "species": shape (128, 312), type "<i8">)]

        Only ``species``, ``coordinates``, ``energies`` and ``forces`` are
        read. Energies and forces are multiplied by ``self.HARTREE_TO_EV``.

        Args:
            mol_ids: when true, every sample additionally carries a ``mol_id``
                string of the form ``"<group key>_<index within group>"``.

        Yields:
            The result of ``self.filter_and_pre_transform`` for every sample
            for which that result is truthy.
        """
        assert len(self.raw_paths) == 1
        # Open explicitly read-only so the raw file can never be modified,
        # independent of the default mode of the installed h5py version.
        with h5py.File(self.raw_paths[0], "r") as h5data:
            for key, group in tqdm(h5data.items(), desc="Molecule Group", leave=False):
                all_z = pt.tensor(group["species"][:], dtype=pt.long)
                all_pos = pt.tensor(group["coordinates"][:], dtype=pt.float32)
                all_y = pt.tensor(
                    group["energies"][:] * self.HARTREE_TO_EV, dtype=pt.float64
                )
                all_neg_dy = pt.tensor(
                    group["forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32
                )
                n_mols = all_pos.shape[0]
                n_atoms = all_pos.shape[1]

                # All arrays of a group must agree on the number of
                # conformations and atoms.
                assert all_y.shape[0] == n_mols
                assert all_z.shape == (n_mols, n_atoms)
                assert all_pos.shape == (n_mols, n_atoms, 3)
                assert all_neg_dy.shape == (n_mols, n_atoms, 3)

                for i, (pos, y, z, neg_dy) in enumerate(
                    zip(all_pos, all_y, all_z, all_neg_dy)
                ):
                    # Create a sample
                    args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy)
                    if mol_ids:
                        args["mol_id"] = f"{key}_{i}"
                    # Use a dedicated name so the HDF5 group variable of the
                    # outer loop is not shadowed.
                    sample = Data(**args)

                    if sample := self.filter_and_pre_transform(sample):
                        yield sample

    def get_atomref(self, max_z=100):
        """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.

        Args:
            max_z: number of rows of the returned table. It must be larger
                than the highest atomic number in ``_ELEMENT_ENERGIES`` (17),
                otherwise the index assignment below fails.

        Returns:
            Tensor of shape ``(max_z, 1)`` holding, at row ``Z``, the entry of
            ``_ELEMENT_ENERGIES`` for ``Z`` multiplied by
            ``self.HARTREE_TO_EV``; rows of elements without an entry are zero.
        """
        refs = pt.zeros(max_z)
        for key, val in self._ELEMENT_ENERGIES.items():
            refs[key] = val * self.HARTREE_TO_EV

        return refs.view(-1, 1)

    # Circumvent https://github.com/pyg-team/pytorch_geometric/issues/4567
    # TODO remove when fixed
    def process(self):
        """Delegate to the parent implementation (kept only as the workaround noted above)."""
        super().process()

0 comments on commit 8e113eb

Please sign in to comment.