Skip to content

Commit db08e15

Browse files
committed
Add initial code.
1 parent 6556e3c commit db08e15

12 files changed

+3568
-0
lines changed

.gitignore

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
cluster/*
2+
3+
.idea/*
4+
5+
data/*.smi
6+
data/*.json
7+
data/*.pkl
8+
9+
histogramAnalysis/*.png
10+
11+
utils/fpscores.pkl.gz
12+
utils/sascorer.py
13+
14+
results/*
15+
16+
*__*__/*
17+
*.pyc
18+
tmp/*

CCGVAE.py

+1,563
Large diffs are not rendered by default.

ccgvae_env.yml

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
name: ccgvae
2+
channels:
3+
- rdkit
4+
- defaults
5+
dependencies:
6+
- _tflow_select=2.1.0=gpu
7+
- absl-py=0.4.1=py35_0
8+
- astor=0.7.1=py35_0
9+
- blas=1.0=mkl
10+
- bzip2=1.0.6=h14c3975_5
11+
- ca-certificates=2018.03.07=0
12+
- cairo=1.14.12=h8948797_3
13+
- certifi=2018.8.24=py35_1
14+
- cudatoolkit=9.2=0
15+
- cudnn=7.3.1=cuda9.2_0
16+
- cupti=9.2.148=0
17+
- fontconfig=2.13.0=h9420a91_0
18+
- freetype=2.9.1=h8a8886c_1
19+
- gast=0.2.0=py35_0
20+
- glib=2.56.2=hd408876_0
21+
- grpcio=1.12.1=py35hdbcaa40_0
22+
- icu=58.2=h9c2bf20_1
23+
- intel-openmp=2019.1=144
24+
- jpeg=9b=h024ee3a_2
25+
- libboost=1.65.1=habcd387_4
26+
- libedit=3.1.20170329=h6b74fdf_2
27+
- libffi=3.2.1=hd88cf55_4
28+
- libgcc-ng=8.2.0=hdf63c60_1
29+
- libgfortran-ng=7.3.0=hdf63c60_0
30+
- libpng=1.6.35=hbc83047_0
31+
- libprotobuf=3.6.0=hdbcaa40_0
32+
- libstdcxx-ng=8.2.0=hdf63c60_1
33+
- libtiff=4.0.9=he85c1e1_2
34+
- libuuid=1.0.3=h1bed415_2
35+
- libxcb=1.13=h1bed415_1
36+
- libxml2=2.9.8=h26e45fe_1
37+
- markdown=2.6.11=py35_0
38+
- mkl=2018.0.3=1
39+
- mkl_fft=1.0.6=py35h7dd41cf_0
40+
- mkl_random=1.0.1=py35h4414c95_1
41+
- ncurses=6.1=he6710b0_1
42+
- numpy=1.14.5
43+
- numpy-base=1.14.5
44+
- olefile=0.46=py35_0
45+
- openssl=1.0.2p=h14c3975_0
46+
- pandas=0.23.4=py35h04863e7_0
47+
- pcre=8.42=h439df22_0
48+
- pillow=5.2.0=py35heded4f4_0
49+
- pip=10.0.1=py35_0
50+
- pixman=0.34.0=hceecf20_3
51+
- protobuf=3.6.0=py35hf484d3e_0
52+
- py-boost=1.65.1=py35hf484d3e_4
53+
- python=3.5.6=hc3d631a_0
54+
- python-dateutil=2.7.3=py35_0
55+
- pytz=2018.5=py35_0
56+
- readline=7.0=h7b6447c_5
57+
- setuptools=39.1.0
58+
- six=1.11.0=py35_1
59+
- sqlite=3.25.3=h7b6447c_0
60+
- tensorboard=1.10.0=py35hf484d3e_0
61+
- tensorflow=1.10.0=gpu_py35hd9c640d_0
62+
- tensorflow-base=1.10.0=gpu_py35had579c0_0
63+
- tensorflow-gpu=1.10.0=hf154084_0
64+
- termcolor=1.1.0=py35_1
65+
- tk=8.6.8=hbc83047_0
66+
- werkzeug=0.14.1=py35_0
67+
- wheel=0.31.1=py35_0
68+
- xz=5.2.4=h14c3975_4
69+
- zlib=1.2.11=h7b6447c_3
70+
- rdkit=2018.03.4.0=py35h71b666b_1
71+
- pip:
72+
- bleach==1.5.0
73+
- cython==0.29.1
74+
- docopt==0.6.2
75+
- html5lib==0.9999999
76+
- mkl-fft==1.0.6
77+
- mkl-random==1.0.1
78+
- planarity==0.4.1
79+
- tensorflow-tensorboard==0.1.8
80+
- typing==3.6.6
81+
prefix: /home/user/Programs/miniconda/envs/givae
82+

ccgvae_env_requirements.txt

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
absl-py==0.4.1
2+
astor==0.7.1
3+
bleach==1.5.0
4+
certifi==2018.8.24
5+
Cython==0.29.1
6+
docopt==0.6.2
7+
gast==0.2.0
8+
grpcio==1.12.1
9+
html5lib==0.9999999
10+
Markdown==3.0.1
11+
mkl-fft==1.0.6
12+
mkl-random==1.0.1
13+
olefile==0.46
14+
pandas==0.23.4
15+
Pillow==5.2.0
16+
planarity==0.4.1
17+
protobuf==3.6.1
18+
python-dateutil==2.7.3
19+
pytz==2018.5
20+
six==1.11.0
21+
tensorboard==1.10.0
22+
tensorflow==1.10.0
23+
termcolor==1.1.0
24+
typing==3.6.6
25+
Werkzeug==0.14.1

data/make_dataset.py

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#!/usr/bin/env/python
"""
Usage:
    make_dataset.py [options]

Options:
    -h --help        Show this screen.
    --dataset NAME   QM9 or ZINC
"""

import json
import os
import sys

import numpy as np
from docopt import docopt
from rdkit import Chem
from rdkit.Chem import QED

# Make the project root importable BEFORE importing `utils`, which lives one
# directory up.  The original appended to sys.path only *after* the import,
# so the path fix could never affect that import.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import utils

# get current directory in order to work with full path and not dynamic
current_dir = os.path.dirname(os.path.realpath(__file__))
26+
27+
28+
def readStr_qm9():
    """Read the QM9 SMILES file and return its lines, deterministically shuffled.

    Reads ``<current_dir>/qm9.smi`` (one SMILES string per line — assumed;
    confirm against the data file), strips surrounding whitespace from each
    line, then shuffles with a fixed seed so the split downstream is
    reproducible.

    Returns:
        list[str]: shuffled, stripped lines of the file.
    """
    # `with` guarantees the handle is closed even if reading raises; the
    # original open()/close() pair leaked the descriptor on error.
    with open(current_dir + '/qm9.smi', 'r') as f:
        L = [line.strip() for line in f]
    # Fixed seed: callers rely on the same shuffle order on every run.
    np.random.seed(1)
    np.random.shuffle(L)
    return L
38+
39+
40+
def read_zinc():
    """Read the ZINC SMILES file and return its lines in file order.

    Reads ``<current_dir>/zinc.smi`` (one SMILES string per line — assumed;
    confirm against the data file) and strips surrounding whitespace from
    each line.  Unlike :func:`readStr_qm9`, no shuffling is applied.

    Returns:
        list[str]: stripped lines of the file.
    """
    # `with` guarantees the handle is closed even if reading raises; the
    # original open()/close() pair leaked the descriptor on error.
    with open(current_dir + '/zinc.smi', 'r') as f:
        return [line.strip() for line in f]
48+
49+
50+
def train_valid_split(dataset):
    """Split a list of SMILES strings into train/valid/test sections.

    The first ``n_test`` indices become the test set; ~10% of indices are
    drawn at random (possibly with repeats, so the valid set can be slightly
    smaller than 10%) from the remainder for validation; everything else is
    train.  Molecules whose valence histogram cannot be built (unrecognized
    atom type) are dropped and counted.

    NOTE(review): the parameter name shadows the module-level ``dataset``
    string that :func:`make_hist` reads — here it is the *list* of SMILES.

    Args:
        dataset (list[str]): SMILES strings.

    Returns:
        dict: {'train': [...], 'valid': [...], 'test': [...]} where each
        entry is {'smiles': str, 'QED': float, 'hist': list}.
    """
    n_mol_out = 0
    n_test = 5000
    # Sets give O(1) membership tests in the loop below; the original tested
    # `i in <numpy array>` which is a linear scan per molecule (O(n^2) total).
    test_idx = set(range(n_test))
    valid_idx = set(np.random.randint(n_test, high=len(dataset), size=round(len(dataset) * 0.1)))

    # save the train, valid dataset.
    raw_data = {'train': [], 'valid': [], 'test': []}
    file_count = 0
    for i, smiles in enumerate(dataset):
        # Build the histogram first so rejected molecules skip the QED
        # computation entirely (the original computed QED unconditionally).
        hist = make_hist(smiles)
        if hist is not None:
            val = QED.qed(Chem.MolFromSmiles(smiles))
            if i in valid_idx:
                raw_data['valid'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            elif i in test_idx:
                raw_data['test'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            else:
                raw_data['train'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            file_count += 1
            if file_count % 1000 == 0:
                print('Finished reading: %d' % file_count, end='\r')
        else:
            n_mol_out += 1

    print("Number of molecules left out: ", n_mol_out)
    return raw_data
77+
78+
79+
def make_hist(smiles):
    """Build a per-valence atom-count histogram for one molecule.

    Reads the module-level global ``dataset`` ('qm9' or 'zinc') to decide how
    atoms are keyed: plain symbol for QM9, ``"<symbol><valence>(<charge>)"``
    for ZINC.

    Args:
        smiles (str): a SMILES string parseable by RDKit.

    Returns:
        np.ndarray | None: histogram of length ``hist_dim`` indexed by
        (maximum valence - 1), or None when any atom type is not listed in
        ``utils.dataset_info(dataset)['atom_types']``.
    """
    mol = Chem.MolFromSmiles(smiles)
    atoms = mol.GetAtoms()
    # Hoist the dataset metadata out of the loop: the original called
    # utils.dataset_info(dataset) up to three times per atom.
    info = utils.dataset_info(dataset)
    hist = np.zeros(info['hist_dim'])
    for atom in atoms:
        if dataset == 'qm9':
            atom_str = atom.GetSymbol()
        else:
            # zinc dataset # transform using "<atom_symbol><valence>(<charge>)" notation
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            charge = atom.GetFormalCharge()
            atom_str = "%s%i(%i)" % (symbol, valence, charge)

        if atom_str not in info['atom_types']:
            print('Unrecognized atom type %s' % atom_str)
            return None

        ind = info['atom_types'].index(atom_str)
        val = info['maximum_valence'][ind]
        hist[val - 1] += 1  # in the array the valence number start from 1, instead the array start from 0
    return hist
101+
102+
103+
def preprocess(raw_data, dataset):
    """Convert split SMILES records to graph form and dump each split to JSON.

    For every molecule in each section ('train'/'valid'/'test') the SMILES is
    turned into (nodes, edges) via ``utils.to_graph``; molecules that yield no
    edges are skipped with a message.  Each section is written to
    ``molecules_<section>_<dataset>.json`` in the current working directory.

    Args:
        raw_data (dict): output of :func:`train_valid_split`.
        dataset (str): dataset name used by ``utils.to_graph`` ('qm9'/'zinc').
    """
    print('Parsing smiles as graphs...')
    processed_data = {'train': [], 'valid': [], 'test': []}

    file_count = 0
    for section in ['train', 'valid', 'test']:
        # Iterate the records directly; the original built a throwaway list
        # of tuples, enumerated it with an unused index, and also collected
        # an `all_smiles` list that was never read (dead code, removed).
        for mol in raw_data[section]:
            smiles = mol['smiles']
            nodes, edges = utils.to_graph(smiles, dataset)
            if len(edges) <= 0:
                print('Error. Molecule with len(edges) <= 0')
                continue
            processed_data[section].append({
                'targets': [[mol['QED']]],
                'graph': edges,
                'node_features': nodes,
                'smiles': smiles,
                'hist': mol['hist']
            })
            if file_count % 1000 == 0:
                print('Finished processing: %d' % file_count, end='\r')
            file_count += 1
        print('%s: 100 %% ' % (section))
        with open('molecules_%s_%s.json' % (section, dataset), 'w') as f:
            json.dump(processed_data[section], f)

    print("Train molecules = " + str(len(processed_data['train'])))
    print("Valid molecules = " + str(len(processed_data['valid'])))
    print("Test molecules = " + str(len(processed_data['test'])))
134+
135+
136+
if __name__ == "__main__":
    # Parse the docopt usage string at the top of this module.
    args = docopt(__doc__)
    dataset = args.get('--dataset')

    print('Reading dataset: ' + str(dataset))
    # Dispatch table instead of an if/elif chain; unknown names abort.
    readers = {'qm9': readStr_qm9, 'zinc': read_zinc}
    reader = readers.get(dataset)
    if reader is None:
        print('Error. The database doesn\'t exist')
        exit(1)
    data = reader()

    raw_data = train_valid_split(data)
    preprocess(raw_data, dataset)

0 commit comments

Comments
 (0)