-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpredict.py
99 lines (88 loc) · 3.35 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python
"""
Script for predicting TF binding with a trained model.
Use `predict.py -h` to see an auto-generated description of advanced options.
"""
import numpy as np
import pylab
import matplotlib
import pandas
import utils
import pickle
# Standard library imports
import sys
import os
import errno
import argparse
def make_argument_parser():
    """
    Creates an ArgumentParser to read the options for this script from
    sys.argv

    The help epilog is taken from the module docstring, minus its first
    line. When the module has no docstring (e.g. Python started with -OO,
    which strips docstrings), the epilog is left empty.

    Returns:
        argparse.ArgumentParser: parser exposing the required options
        --inputdir/-i, --modeldir/-m, --factor/-f, --bed/-b and
        --outputfile/-o.
    """
    # __doc__ is None when docstrings are stripped or missing; calling
    # .strip() on it would crash before any argument is parsed.
    module_doc = __doc__ or ''
    epilog = '\n'.join(module_doc.strip().split('\n')[1:]).strip()

    parser = argparse.ArgumentParser(
        description="Generate predictions from a trained model.",
        epilog=epilog,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--inputdir', '-i', type=str, required=True,
                        help='Folder containing input data')
    parser.add_argument('--modeldir', '-m', type=str, required=True,
                        help='Folder containing trained model generated by train.py.')
    parser.add_argument('--factor', '-f', type=str, required=True,
                        help='The transcription factor to evaluate.')
    parser.add_argument('--bed', '-b', type=str, required=True,
                        help='Sorted BED file containing intervals to predict on.')
    parser.add_argument('--outputfile', '-o', type=str, required=True,
                        help='The output filename.')
    return parser
def _load_model_meta_names(model_dir):
    """
    Reads the metadata feature names saved alongside a trained model.

    Args:
        model_dir: Folder containing the trained model; meta.txt is
            looked up directly inside it.

    Returns:
        list: the names read from meta.txt.

    Raises:
        IOError: if meta.txt does not exist in model_dir.
    """
    model_meta_file = os.path.join(model_dir, 'meta.txt')
    if not os.path.isfile(model_meta_file):
        raise IOError(errno.ENOENT, os.strerror(errno.ENOENT), model_meta_file)
    model_meta_names = np.loadtxt(model_meta_file, dtype=str)
    # A 0-d result (a single name in the file) cannot be iterated, so it is
    # wrapped into a one-element list instead.
    if model_meta_names.ndim == 0:
        return [str(model_meta_names)]
    return list(model_meta_names)


def main():
    """
    The main executable function

    Parses the command line, loads the genome and the trained model, checks
    that the requested factor and the input features match what the model
    was trained with, runs the model over the intervals of the BED file and
    writes the BED rows with the score in column 3 to a gzip-compressed,
    tab-separated output file.

    Raises:
        ValueError: if the factor is unknown to the model, the model has no
            'bigwig' feature, or the bigwig/metadata names of the input
            data differ from those of the model.
        IOError: if the model uses metadata but meta.txt is missing from
            the model folder.
    """
    args = make_argument_parser().parse_args()
    input_dir = args.inputdir
    model_dir = args.modeldir
    tf = args.factor
    bed_file = args.bed
    output_file = args.outputfile

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('Loading genome')
    genome = utils.load_genome()

    print('Loading model')
    model_tfs, model_bigwig_names, features, model = utils.load_model(model_dir)
    # utils is configured through its module-level L, taken from the model's
    # first input shape (presumably the sequence window length -- confirm
    # against utils).
    L = model.input_shape[0][1]
    utils.L = L

    # Explicit checks rather than assert: assert statements are removed
    # when Python runs with -O, which would let a mismatch slip through.
    if tf not in model_tfs:
        raise ValueError('Factor %r is not among the factors of the model: %s'
                         % (tf, model_tfs))
    if 'bigwig' not in features:
        raise ValueError("The model does not use the required 'bigwig' "
                         "feature (features: %s)" % (features,))
    use_meta = 'meta' in features
    use_gencode = 'gencode' in features

    print('Loading test data')
    is_sorted = True
    bigwig_names, meta_names, datagen_bed, nonblacklist_bools = utils.load_beddata(
        genome, bed_file, use_meta, use_gencode, input_dir, is_sorted)
    if bigwig_names != model_bigwig_names:
        raise ValueError('Input bigwig names %s do not match the model '
                         'bigwig names %s' % (bigwig_names, model_bigwig_names))
    if use_meta:
        model_meta_names = _load_model_meta_names(model_dir)
        if meta_names != model_meta_names:
            raise ValueError('Input metadata names %s do not match the model '
                             'metadata names %s' % (meta_names, model_meta_names))

    print('Generating predictions')
    model_tf_index = model_tfs.index(tf)
    model_predicts = model.predict_generator(datagen_bed, val_samples=len(datagen_bed), pickle_safe=True)
    if len(model_tfs) > 1:
        # One column per factor: keep only the requested factor.
        model_tf_predicts = model_predicts[:, model_tf_index]
    else:
        # A single-factor model may return a column-shaped (N, 1) array
        # (TODO confirm against the model definition); flatten it so it can
        # be assigned through the 1-D mask below. No-op for 1-D output.
        model_tf_predicts = np.ravel(model_predicts)
    # Entries not selected by nonblacklist_bools keep a score of 0
    # (assumes nonblacklist_bools is a boolean mask over the BED rows).
    final_scores = np.zeros(len(nonblacklist_bools))
    final_scores[nonblacklist_bools] = model_tf_predicts

    print('Saving predictions')
    df = pandas.read_csv(bed_file, sep='\t', header=None)
    df[3] = final_scores
    df.to_csv(output_file, sep='\t', compression='gzip', float_format='%.3e', header=False, index=False)
if __name__ == '__main__':
    # Script entry point. See the module-level docstring for a description
    # of the script.
    main()