# make_predictions.py
from argparse import Namespace
from typing import List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm
# import pandas as pd  # only needed for the commented-out per-model CSV dump below

from .predict import predict
from chemprop.data import MoleculeDataset
from chemprop.data.utils import get_data, get_data_from_smiles
from chemprop.utils import load_args, load_checkpoint, load_scalers


def make_predictions(args: Namespace,
                     smiles: Optional[List[str]] = None
                     ) -> Tuple[List[List[float]], List[str], List[List[float]]]:
    """
    Makes predictions. If smiles is provided, makes predictions on smiles;
    otherwise makes predictions on the data at args.test_path.

    A minimal usage sketch appears at the bottom of this file.

    :param args: Arguments.
    :param smiles: SMILES strings to make predictions on.
    :return: A tuple of (ensemble-averaged predictions, valid test SMILES,
        per-prediction standard deviations across the ensemble models).
    """
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')
    scaler, features_scaler = load_scalers(args.checkpoint_paths[0])
    train_args = load_args(args.checkpoint_paths[0])

    # Update args with training arguments
    for key, value in vars(train_args).items():
        if not hasattr(args, key):
            setattr(args, key, value)

    print('Loading data')
    if smiles is not None:
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
    else:
        test_data = get_data(path=args.test_path, args=args,
                             use_compound_names=args.use_compound_names,
                             skip_invalid_smiles=False)

    print('Validating SMILES')
    valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None]
    full_data = test_data
    test_data = MoleculeDataset([test_data[i] for i in valid_indices])
    # Edge case: an empty list of SMILES was provided; return a tuple shaped
    # like the normal return value (the original returned a bare list here)
    if len(test_data) == 0:
        return [None] * len(full_data), [], []
    if args.use_compound_names:
        compound_names = test_data.compound_names()  # collected as in upstream chemprop, though unused here

    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if train_args.features_scaling:
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros((len(test_data), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), args.num_tasks))

    print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')
    #
    # Modification by RL: in addition to summing predictions for the ensemble
    # average, stack every model's predictions so their standard deviation can
    # be computed below (see the shape note after this loop).
    #
    for model_index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda)
        model_preds = predict(
            model=model,
            data=test_data,
            batch_size=args.batch_size,
            scaler=scaler
        )
        m_preds = np.array(model_preds)
        sum_preds += m_preds

        # Add a trailing model axis so predictions can be stacked along it
        mm_preds = m_preds[:, :, np.newaxis]
        # df = pd.DataFrame({'smiles': test_data.smiles()})
        # for i in range(len(m_preds[0])):
        #     df[f'pred_{i}'] = [item[i] for item in m_preds]
        # df.to_csv(f'./pred_out_{model_index}.csv', index=False)
        if model_index == 0:
            all_preds = mm_preds
        else:
            all_preds = np.concatenate((all_preds, mm_preds), axis=2)
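
    # Shape note: for scalar per-task predictions, all_preds now has shape
    # (num_molecules, num_tasks, num_models); e.g. with 3 checkpoints the last
    # axis holds the 3 models' values for each molecule and task, so
    # np.std(all_preds, axis=2) below yields one spread per prediction.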

    # Ensemble predictions: mean and standard deviation across the models
    avg_preds = (sum_preds / len(args.checkpoint_paths)).tolist()
    std_preds = np.std(all_preds, axis=2).tolist()

    return avg_preds, test_data.smiles(), std_preds
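

# A minimal usage sketch, not part of the original script: the checkpoint and
# CSV paths below are placeholders, and any settings not supplied here
# (dataset_type, num_tasks, batch_size, ...) are filled in from the training
# args saved in the first checkpoint. Because this module uses a relative
# import, it would be run as a module within its package rather than as a
# standalone script.
if __name__ == '__main__':
    args = Namespace(
        gpu=None,                                       # or a CUDA device index
        cuda=torch.cuda.is_available(),
        checkpoint_paths=['model_0.pt', 'model_1.pt'],  # hypothetical ensemble checkpoints
        test_path='test_smiles.csv',                    # hypothetical input CSV of molecules
        use_compound_names=False,
    )
    avg_preds, valid_smiles, std_preds = make_predictions(args)
    for smile, avg, std in zip(valid_smiles, avg_preds, std_preds):
        print(smile, avg, std)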