Skip to content

Commit

Permalink
Always cast dataset to the chosen float type (allows training in fp32 with
Browse files Browse the repository at this point in the history
a dataset that provides fp64)
  • Loading branch information
RaulPPelaez committed Jan 18, 2024
1 parent 592ca86 commit 5ab3aed
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class FloatCastDatasetWrapper(Dataset):
"""A wrapper around a torch_geometric dataset that casts all floating point
tensors to a given dtype.
"""

def __init__(self, dataset, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__(
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
Expand Down Expand Up @@ -79,15 +80,16 @@ def setup(self, stage):
if self.hparams["dataset_arg"] is not None:
dataset_arg = self.hparams["dataset_arg"]
if self.hparams["dataset"] == "HDF5":
dataset_arg["dataset_preload_limit"] = self.hparams["dataset_preload_limit"]
dataset_arg["dataset_preload_limit"] = self.hparams[
"dataset_preload_limit"
]
self.dataset = getattr(datasets, self.hparams["dataset"])(
self.hparams["dataset_root"], **dataset_arg
)

if self.hparams["precision"] != 32:
self.dataset = FloatCastDatasetWrapper(
self.dataset, dtype_mapping[self.hparams["precision"]]
)
self.dataset = FloatCastDatasetWrapper(
self.dataset, dtype_mapping[self.hparams["precision"]]
)

self.idx_train, self.idx_val, self.idx_test = make_splits(
len(self.dataset),
Expand Down

0 comments on commit 5ab3aed

Please sign in to comment.