#!/usr/bin/env python
"""
Usage:
    make_dataset.py [options]

Options:
    -h --help        Show this screen.
    --dataset NAME   qm9 or zinc
"""
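# Example invocation (qm9.smi / zinc.smi are expected to sit next to this script):
#   python make_dataset.py --dataset qm9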

import json
import os
import sys

import numpy as np
from docopt import docopt
from rdkit import Chem
from rdkit.Chem import QED

# make the project root importable before pulling in the shared utils module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import utils

# resolve this script's directory so data files are read via absolute paths
current_dir = os.path.dirname(os.path.realpath(__file__))


def readStr_qm9():
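    """Read qm9.smi (one SMILES string per line) and return the list shuffled with a fixed seed."""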
    with open(os.path.join(current_dir, 'qm9.smi'), 'r') as f:
        L = [line.strip() for line in f]
    np.random.seed(1)
    np.random.shuffle(L)
    return L


def read_zinc():
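    """Read zinc.smi (one SMILES string per line) and return the list in file order."""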
    with open(os.path.join(current_dir, 'zinc.smi'), 'r') as f:
        L = [line.strip() for line in f]
    return L


def train_valid_split(dataset):
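    """Split the SMILES list into test (first 5000), valid (~10%, drawn at random) and train sets.

    Each kept molecule is stored with its QED score and valence histogram; molecules whose
    atom types are not recognized by make_hist() are left out. Note that make_hist() reads
    the module-level `dataset` name set in the __main__ block.
    """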
    n_mol_out = 0
    n_test = 5000
    test_idx = np.arange(0, n_test)
    valid_idx = np.random.randint(n_test, high=len(dataset), size=round(len(dataset) * 0.1))

    # sort every molecule into the train / valid / test bucket
    raw_data = {'train': [], 'valid': [], 'test': []}
    file_count = 0
    for i, smiles in enumerate(dataset):
        val = QED.qed(Chem.MolFromSmiles(smiles))
        hist = make_hist(smiles)
        if hist is not None:
            if i in valid_idx:
                raw_data['valid'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            elif i in test_idx:
                raw_data['test'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            else:
                raw_data['train'].append({'smiles': smiles, 'QED': val, 'hist': hist.tolist()})
            file_count += 1
            if file_count % 1000 == 0:
                print('Finished reading: %d' % file_count, end='\r')
        else:
            n_mol_out += 1

    print("Number of molecules left out: ", n_mol_out)
    return raw_data


def make_hist(smiles):
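    """Count, for one molecule, how many atoms have each maximum valence.

    The maximum valence of every atom type comes from utils.dataset_info(dataset);
    returns None if any atom type is not listed there. Relies on the module-level
    `dataset` name.
    """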
    mol = Chem.MolFromSmiles(smiles)
    atoms = mol.GetAtoms()
    info = utils.dataset_info(dataset)
    hist = np.zeros(info['hist_dim'])
    for atom in atoms:
        if dataset == 'qm9':
            atom_str = atom.GetSymbol()
        else:
            # zinc: encode each atom as "<atom_symbol><valence>(<charge>)"
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            charge = atom.GetFormalCharge()
            atom_str = "%s%i(%i)" % (symbol, valence, charge)

        if atom_str not in info['atom_types']:
            print('Unrecognized atom type %s' % atom_str)
            return None

        ind = info['atom_types'].index(atom_str)
        val = info['maximum_valence'][ind]
        hist[val - 1] += 1  # valences start at 1, array indices at 0
    return hist


def preprocess(raw_data, dataset):
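    """Convert every SMILES string into a graph and write each split to
    molecules_<section>_<dataset>.json in the current working directory."""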
    print('Parsing smiles as graphs...')
    processed_data = {'train': [], 'valid': [], 'test': []}

    file_count = 0
    for section in ['train', 'valid', 'test']:
        all_smiles = []  # record all smiles of this section (currently not written out)
        for mol in raw_data[section]:
            smiles, qed, hist = mol['smiles'], mol['QED'], mol['hist']
            nodes, edges = utils.to_graph(smiles, dataset)
            if len(edges) <= 0:
                print('Error. Molecule with len(edges) <= 0')
                continue
            processed_data[section].append({
                'targets': [[qed]],
                'graph': edges,
                'node_features': nodes,
                'smiles': smiles,
                'hist': hist
            })
            all_smiles.append(smiles)
            if file_count % 1000 == 0:
                print('Finished processing: %d' % file_count, end='\r')
            file_count += 1
        print('%s: 100 %%' % section)
        with open('molecules_%s_%s.json' % (section, dataset), 'w') as f:
            json.dump(processed_data[section], f)

    print("Train molecules = " + str(len(processed_data['train'])))
    print("Valid molecules = " + str(len(processed_data['valid'])))
    print("Test molecules = " + str(len(processed_data['test'])))


if __name__ == "__main__":
    args = docopt(__doc__)
    dataset = args.get('--dataset')

    print('Reading dataset: ' + str(dataset))
    data = []
    if dataset == 'qm9':
        data = readStr_qm9()
    elif dataset == 'zinc':
        data = read_zinc()
    else:
        print('Error. Unknown dataset "%s"; expected qm9 or zinc' % dataset)
        sys.exit(1)

    raw_data = train_valid_split(data)
    preprocess(raw_data, dataset)