-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmonitor.py
127 lines (118 loc) · 4.49 KB
/
monitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
from alms.database import *
from alms.args import MonitorArgs
from alms.ml.mgk.args import ActiveLearningArgs
from alms.ml.mgk.kernels.utils import get_kernel_config
from alms.ml.mgk.data.data import Dataset
from alms.ml.mgk.evaluator import ActiveLearner
def _build_learning_args(margs):
    """Return the ActiveLearningArgs used by `active_learning`.

    Values that come from the monitor options (`n_jobs`, `graph_kernel_type`,
    `pool_size`, `stop_uncertainty`, `seed`) are copied from `margs`; every
    other field is a fixed setting for an unsupervised GPR run.
    """
    args = ActiveLearningArgs()
    # Data description: a single SMILES column and a single target column.
    args.dataset_type = 'regression'
    args.save_dir = 'data'
    args.n_jobs = margs.n_jobs
    args.pure_columns = ['smiles']
    args.target_columns = ['target']
    # Kernel and model.
    args.graph_kernel_type = margs.graph_kernel_type
    args.graph_hyperparameters = ['data/tMGR.json']
    args.model_type = 'gpr'
    args.alpha = 0.01
    args.optimizer = None
    args.batch_size = None
    # Active-learning loop control.
    args.learning_algorithm = 'unsupervised'
    args.add_size = 1
    args.cluster_size = 1
    args.pool_size = margs.pool_size
    args.stop_size = 100000
    args.stop_uncertainty = [margs.stop_uncertainty]
    args.evaluate_stride = 0
    args.seed = margs.seed
    return args


def _smiles_dataset(args, smiles):
    """Build a Dataset from a list of SMILES with a constant 0.0 target."""
    # The target column is only a placeholder; its length is taken from the
    # same list as the SMILES so both columns always agree.
    frame = pd.DataFrame({'smiles': smiles,
                          'target': [0.] * len(smiles)})
    dataset = Dataset.from_df(args, frame)
    dataset.update_args(args)
    return dataset


def _leading_smiles(x_repr):
    """Return the set of first comma-separated fields of each entry in `x_repr`."""
    return {entry.split(',')[0] for entry in x_repr.ravel()}


def active_learning(margs: MonitorArgs):
    """Update the `active` / `inactive` flags of molecules in the database.

    Behaviour depends on `margs.stop_uncertainty`:

    * `None`  -- do nothing.
    * `< 0.0` -- flag every molecule (failed ones included) as active.
    * otherwise -- run an unsupervised active-learning pass over the
      non-failed molecules. Molecules left in the learner's dataset are
      flagged active; molecules left in the learner's pool are flagged
      inactive. Changes are committed through the module-level `session`.

    Returns None in every case. Returns early, without running the learner,
    when fewer than two non-failed molecules are available for seeding or when
    there is no unlabelled pool molecule.
    """
    if margs.stop_uncertainty is None:
        return

    if margs.stop_uncertainty < 0.0:
        # Fetch once; the progress-bar total comes from the same list.
        everything = session.query(Molecule).all()
        for mol in tqdm(everything, total=len(everything)):
            mol.active = True
            mol.inactive = False
        session.commit()
        return

    args = _build_learning_args(margs)
    mols_all = session.query(Molecule).filter_by(fail=False)

    # Randomly select 2 samples as the start of active learning.
    if mols_all.filter_by(active=True).count() <= 1:
        candidates = mols_all.all()
        if len(candidates) < 2:
            # Sampling 2 without replacement from fewer than 2 items raises
            # ValueError; there is nothing to learn from yet, so skip.
            return
        # Sample indices rather than the ORM objects themselves so numpy
        # never has to coerce mapped instances into an array.
        for idx in np.random.choice(len(candidates), 2, replace=False):
            seed_mol = candidates[idx]
            seed_mol.active = True
            seed_mol.inactive = False
        session.commit()

    # Selected (already active) data set.
    active_smiles = [mol.smiles for mol in mols_all.filter_by(active=True)]
    dataset = _smiles_dataset(args, active_smiles)

    # Pool data set: molecules that are neither active nor inactive yet.
    pool_query_limit = 50000  # upper bound on pool rows fetched per pass
    pool_query = mols_all.filter_by(active=False, inactive=False)
    pool_smiles = [mol.smiles for mol in pool_query.limit(pool_query_limit)]
    if not pool_smiles:
        return
    dataset_pool = _smiles_dataset(args, pool_smiles)

    # Full data set (selected + pool), used to build the kernel config.
    dataset_full = dataset.copy()
    dataset_full.data = dataset.data + dataset_pool.data
    dataset_full.unify_datatype(dataset_full.X_graph)

    kernel_config = get_kernel_config(
        args, dataset_full,
        kernel_pkl=os.path.join(args.save_dir, 'kernel.pkl'))

    # Active learning.
    al = ActiveLearner(args, dataset, dataset_pool, kernel_config, kernel_config)
    al.run()

    # Sets give constant-time membership tests in the loop below.
    selected = _leading_smiles(al.dataset.X_repr) if len(al.dataset) != 0 else set()
    rejected = _leading_smiles(al.dataset_pool.X_repr) if len(al.dataset_pool) != 0 else set()

    if selected or rejected:
        current = mols_all.all()
        for mol in tqdm(current, total=len(current)):
            # The pool label is checked first so that it wins when a SMILES
            # appears in both sets.
            if mol.smiles in rejected:
                mol.active = False
                mol.inactive = True
            elif mol.smiles in selected:
                mol.active = True
                mol.inactive = False
    session.commit()
def monitor(args: MonitorArgs):
    """Run the monitoring loop for the task named by `args.task`.

    Supported tasks are 'qm_cv' and 'md_npt'; for any other value the function
    returns immediately. For a supported task it loops forever: active
    learning, create, build, run, analyze, extend, update failed molecules,
    then sleep for `args.t_sleep` minutes.
    """
    task = args.task
    # The task-specific helpers are imported lazily so that only the selected
    # backend is loaded.
    if task == 'qm_cv':
        from alms.qm.qm_cv import (
            get_GaussianSimulator as make_simulator,
            create, build, run, analyze, extend,
        )
    elif task == 'md_npt':
        from alms.md.md_npt import (
            get_NptSimulator as make_simulator,
            create, build, run, analyze, extend,
        )
    else:
        return
    from alms.analysis.cp import update_fail_mols

    simulator = make_simulator(args)
    job_manager = args.job_manager_

    # Each stage is a (banner, action) pair executed in order on every cycle.
    stages = (
        ('Start a new loop\n'
         'Step1: active learning.\n\n',
         lambda: active_learning(args)),
        ('Step2: create.\n\n',
         lambda: create(args)),
        ('\nStep3: build.\n',
         lambda: build(args, simulator)),
        ('\nStep4: run.\n',
         lambda: run(args, simulator, job_manager)),
        ('\nStep5: analyze.\n',
         lambda: analyze(args, simulator, job_manager)),
        ('\nStep6: extend.\n',
         lambda: extend(args, simulator, job_manager)),
        ('\nStep7: update failed mols.\n',
         lambda: update_fail_mols()),
    )

    while True:
        for banner, action in stages:
            print(banner)
            action()
        print('Sleep %d minutes...' % args.t_sleep)
        # t_sleep is given in minutes; time.sleep expects seconds.
        time.sleep(args.t_sleep * 60)
if __name__ == '__main__':
    # Script entry point: build the monitor options via MonitorArgs.parse_args()
    # (called with no explicit argument list) and hand them to monitor().
    cli_args = MonitorArgs().parse_args()
    monitor(args=cli_args)