Skip to content

Commit

Permalink
Merge pull request #258 from torchmd/fix_processed_paths
Browse files Browse the repository at this point in the history
correct paths for processed files. bugfix for last PR
  • Loading branch information
RaulPPelaez authored Jan 24, 2024
2 parents 625b655 + fce5008 commit 07424e1
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,42 +48,50 @@ def __init__(
self.properties = properties
super().__init__(root, transform, pre_transform, pre_filter)

self.idx_mm = np.memmap(self.fname("idx"), mode="r", dtype=np.int64)
self.z_mm = np.memmap(self.fname("z"), mode="r", dtype=np.int8)
fnames = self.processed_paths_dict

self.idx_mm = np.memmap(fnames["idx"], mode="r", dtype=np.int64)
self.z_mm = np.memmap(fnames["z"], mode="r", dtype=np.int8)
num_all_confs = self.idx_mm.shape[0] - 1
num_all_atoms = self.z_mm.shape[0]
self.pos_mm = np.memmap(
self.fname("pos"), mode="r", dtype=np.float32, shape=(num_all_atoms, 3)
fnames["pos"], mode="r", dtype=np.float32, shape=(num_all_atoms, 3)
)
if "y" in self.properties:
self.y_mm = np.memmap(self.fname("y"), mode="r", dtype=np.float64)
self.y_mm = np.memmap(fnames["y"], mode="r", dtype=np.float64)
if "neg_dy" in self.properties:
neg_dy_name = self.fname("neg_dy")
self.neg_dy_mm = np.memmap(
neg_dy_name, mode="r", dtype=np.float32, shape=(num_all_atoms, 3)
fnames["neg_dy"], mode="r", dtype=np.float32, shape=(num_all_atoms, 3)
)
if "q" in self.properties:
self.q_mm = np.memmap(self.fname("q"), mode="r", dtype=np.int8)
self.q_mm = np.memmap(fnames["q"], mode="r", dtype=np.int8)
if "pq" in self.properties:
self.pq_mm = np.memmap(self.fname("pq"), mode="r", dtype=np.float32)
self.pq_mm = np.memmap(fnames["pq"], mode="r", dtype=np.float32)
if "dp" in self.properties:
self.dp_mm = np.memmap(
self.fname("dp"), mode="r", dtype=np.float32, shape=(num_all_confs, 3)
fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
)

assert self.idx_mm[0] == 0
assert self.idx_mm[-1] == len(self.z_mm)
assert len(self.idx_mm) == len(self.y_mm) + 1

def fname(self, prop):
return f"{self.name}.{prop}.mmap"

@property
def processed_file_names(self):
return [
self.fname(prop) for prop in ["idx", "z", "pos"] + list(self.properties)
f"{self.name}.{prop}.mmap"
for prop in ["idx", "z", "pos"] + list(self.properties)
]

@property
def processed_paths_dict(self):
return {
prop: fname
for prop, fname in zip(
["idx", "z", "pos"] + list(self.properties), self.processed_paths
)
}

@staticmethod
def compute_reference_energy(self):
raise NotImplementedError
Expand All @@ -103,49 +111,51 @@ def process(self):
print(f" Total number of atoms: {num_all_atoms}")
print(f" Properties available: {self.properties}")

fnames = self.processed_paths_dict

idx_mm = np.memmap(
self.fname("idx") + ".tmp",
fnames["idx"] + ".tmp",
mode="w+",
dtype=np.int64,
shape=(num_all_confs + 1,),
)
z_mm = np.memmap(
self.fname("z") + ".tmp", mode="w+", dtype=np.int8, shape=(num_all_atoms,)
fnames["z"] + ".tmp", mode="w+", dtype=np.int8, shape=(num_all_atoms,)
)
pos_mm = np.memmap(
self.fname("pos") + ".tmp",
fnames["pos"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=(num_all_atoms, 3),
)
if "y" in self.properties:
y_mm = np.memmap(
self.fname("y") + ".tmp",
fnames["y"] + ".tmp",
mode="w+",
dtype=np.float64,
shape=(num_all_confs,),
)
if "neg_dy" in self.properties:
neg_dy_mm = np.memmap(
self.fname("neg_dy") + ".tmp",
fnames["neg_dy"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=(num_all_atoms, 3),
)
if "q" in self.properties:
q_mm = np.memmap(
self.fname("q") + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs
fnames["q"] + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs
)
if "pq" in self.properties:
pq_mm = np.memmap(
self.fname("pq") + ".tmp",
fnames["pq"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=num_all_atoms,
)
if "dp" in self.properties:
dp_mm = np.memmap(
self.fname("dp") + ".tmp",
fnames["dp"] + ".tmp",
mode="w+",
dtype=np.float32,
shape=(num_all_confs, 3),
Expand Down Expand Up @@ -188,19 +198,19 @@ def process(self):
if "dp" in self.properties:
dp_mm.flush()

os.rename(idx_mm.filename, self.fname("idx"))
os.rename(z_mm.filename, self.fname("z"))
os.rename(pos_mm.filename, self.fname("pos"))
os.rename(idx_mm.filename, fnames["idx"])
os.rename(z_mm.filename, fnames["z"])
os.rename(pos_mm.filename, fnames["pos"])
if "y" in self.properties:
os.rename(y_mm.filename, self.fname("y"))
os.rename(y_mm.filename, fnames["y"])
if "neg_dy" in self.properties:
os.rename(neg_dy_mm.filename, self.fname("neg_dy"))
os.rename(neg_dy_mm.filename, fnames["neg_dy"])
if "q" in self.properties:
os.rename(q_mm.filename, self.fname("q"))
os.rename(q_mm.filename, fnames["q"])
if "pq" in self.properties:
os.rename(pq_mm.filename, self.fname("pq"))
os.rename(pq_mm.filename, fnames["pq"])
if "dp" in self.properties:
os.rename(dp_mm.filename, self.fname("dp"))
os.rename(dp_mm.filename, fnames["dp"])

def len(self):
return len(self.idx_mm) - 1
Expand Down

0 comments on commit 07424e1

Please sign in to comment.