-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathextract_contexts.py
304 lines (279 loc) · 15.4 KB
/
extract_contexts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
from collections import defaultdict
from Bio import SeqIO
import scipy.stats as ss
import numpy as np
#import cPickle
import pickle #_pickle as cPickle
import sys
import re
import os
# Complement of each base; 'N' (unknown) and 'M' (the methylated-base marker
# used in this module) are their own complements.
base_comps = dict(zip('ACTGNM', 'TGACNM'))


#@profile
def comp(seq, base_comps=base_comps):
    """Return the base-by-base complement of ``seq`` (not reversed).

    Every character must be a key of ``base_comps``; any other character
    (e.g. a lowercase base) raises ``KeyError``.
    """
    lookup = base_comps.__getitem__
    return ''.join(map(lookup, seq))
#@profile
def revcomp(seq, rev=True):
    """Return the reverse complement of ``seq`` when ``rev`` is truthy.

    With a falsy ``rev`` the input is returned unchanged, so callers can pass
    a strand flag straight through.
    """
    return comp(seq)[::-1] if rev else seq
#@profile
def strand(rev):
    """Return the strand symbol for a reverse-strand flag: '-' if truthy, else '+'."""
    return '-' if rev else '+'
# Find occurrences of a motif (e.g. CG) in a reference sequence and change the
# methylated base(s) in each occurrence to M.
#@profile
def methylate_motifs(ref_seq, motif, meth_base, meth_position=None):
    """Mark methylated bases at every occurrence of ``motif`` in ``ref_seq``.

    Args:
        ref_seq: reference sequence string.
        motif: literal motif to search for (e.g. ``'GATC'``).
        meth_base: methylated base; used only when ``meth_position`` is None,
            in which case every copy of it inside the motif becomes ``'M'``.
        meth_position: optional 0-based index into ``motif`` of the single
            methylated base. ``0`` is a valid position.

    Returns:
        ``ref_seq`` with each non-overlapping occurrence of ``motif`` (as
        found left to right by ``str.replace``) rewritten with ``'M'`` at the
        methylated position(s).
    """
    if meth_position is not None:
        # prefix + 'M' + suffix; for an in-range index the motif keeps its length
        meth_motif = motif[:meth_position] + 'M' + motif[meth_position + 1:]
    else:
        meth_motif = 'M'.join(motif.split(meth_base))
    return ref_seq.replace(motif, meth_motif)
# Change specified (0-based) positions to M in a reference sequence.
#@profile
def methylate_positions(ref_seq, positions, meth_base):
    """Return ``ref_seq`` with each listed 0-based position changed to ``'M'``.

    Args:
        ref_seq: reference sequence string.
        positions: iterable of 0-based integer positions to mark.
        meth_base: base expected at every listed position. A position that
            already holds ``'M'`` (e.g. listed twice) is accepted.

    Returns:
        The marked sequence, the same length as ``ref_seq``.

    Raises:
        IndexError: if a position is past the end of ``ref_seq``.
        SystemExit: after printing a message, if a position is negative or
            holds a base other than ``meth_base``/``'M'`` - usually a sign the
            positions are not 0-based. NOTE(review): exits with status 0,
            like the rest of this module.
    """
    # Edit a list of characters and join once: re-slicing the whole string for
    # every position costs O(len(ref_seq)) per position.
    seq_chars = list(ref_seq)
    for pos in positions:  # positions are 0-based - no offset is applied
        if pos >= 0 and seq_chars[pos] in (meth_base, 'M'):
            seq_chars[pos] = 'M'
        else:
            print('Base {} does not correspond to methylated base - check reference positions are 0-based - quitting thread now'.format(pos))
            sys.exit(0)
    return ''.join(seq_chars)
# Methylate one contig's reference sequence on both strands, from a motif or
# from a file of positions.
#@profile
def methylate_references(ref_seq, base, motif=None, positions=None, train=False, contig=None):
    """Build forward- and reverse-strand methylated copies of ``ref_seq``.

    Args:
        ref_seq: reference sequence of a single contig (forward strand).
        base: methylated base on the forward strand (e.g. ``'A'``).
        motif: literal motif to mark; used only when ``positions`` is falsy.
        positions: path to a whitespace-delimited file whose lines start with
            ``contig  0-based-position  strand('+' or '-')``; takes precedence
            over ``motif``. Lines with fewer than three columns, another
            contig, or another strand symbol are ignored.
        train: unused; kept for interface compatibility.
        contig: contig name matched against the first column of the file.

    Returns:
        ``(meth_fwd, meth_rev)``, both in forward-strand coordinates:
        ``meth_fwd`` marks ``base`` at forward-strand sites and ``meth_rev``
        marks its complement (``base_comps[base]``) at reverse-strand sites.
        Exits (status 0) after a message when neither source is given.
    """
    if not positions and motif:
        meth_fwd = methylate_motifs(ref_seq, motif, base)
        meth_rev = methylate_motifs(ref_seq, revcomp(motif), base_comps[base])
    elif positions:
        # Read the file once and close it. Require the three columns used
        # below so short lines are skipped instead of raising IndexError.
        fwd_pos, rev_pos = [], []
        with open(positions, 'r') as posfile:
            for line in posfile:
                fields = line.split()
                if len(fields) < 3 or fields[0] != contig:
                    continue
                if fields[2] == '+':
                    fwd_pos.append(int(fields[1]))
                elif fields[2] == '-':
                    rev_pos.append(int(fields[1]))
        meth_fwd = methylate_positions(ref_seq, fwd_pos, base)
        meth_rev = methylate_positions(ref_seq, rev_pos, base_comps[base])
    else:
        print('no motifs or positions specified')
        sys.exit(0)
    return meth_fwd, meth_rev
#@profile
def find_and_methylate(refname, contigname, base, motif, positions_list):
    """Methylate the FASTA record in ``refname`` whose id equals ``contigname``.

    The first matching record's sequence is upper-cased and passed to
    ``methylate_references``; its ``(meth_fwd, meth_rev)`` pair is returned.
    Returns ``None`` when no record matches - ``extract_features`` relies on
    the resulting ``TypeError`` when it unpacks the result.
    """
    for record in SeqIO.parse(refname, "fasta"):
        if record.id != contigname:
            continue
        upper_seq = str(record.seq).upper()
        return methylate_references(upper_seq, base, motif=motif,
                                    positions=positions_list, contig=contigname)
    return None
def writefi(data, fi):
    """Append ``data`` to file ``fi``, one tab-separated line per entry.

    Each entry must be an iterable of strings. The file is created if missing,
    and is opened and closed even when ``data`` is empty.
    """
    with open(fi, 'a') as outfi:
        outfi.writelines('\t'.join(entry) + '\n' for entry in data)
def adjust_scores(context_dict,context,diffs,prob,k):
    """Compute heuristic scores for one observation against per-context training statistics.

    Args:
        context_dict: nested mapping read as ``context_dict[label][context]``
            for labels ``'m6A'`` and ``'A'``. Each entry needs ``'mean'`` and
            ``'sd'`` (indexable per position, at least ``k`` long) and
            ``'num'`` (a count).
        context: sequence context of the observation.
        diffs: current differences for the observation. The first ``k`` values
            are scored against the m6A normal densities, and the whole
            sequence is correlated with ``'mean'`` (so lengths must match).
        prob: value added into ``representation_score`` (presumably the
            classifier's methylation probability - confirm with callers).
        k: number of leading positions scored by the normal densities.

    Returns:
        ``None`` when ``context`` has no ``'m6A'`` statistics. Otherwise a
        dict with keys ``'hmm_score'``, ``'correlation_score'``,
        ``'correlation_diff'`` and ``'representation_score'``.
        ``'correlation_diff'`` is ``None`` when ``context`` has no ``'A'``
        statistics.
    """
    meth_stats = context_dict['m6A']
    if context not in meth_stats:
        return None
    meth = meth_stats[context]
    # Product of per-position normal densities of diffs under the m6A model.
    likelihood = np.prod([ss.norm(meth['mean'][i],meth['sd'][i]).pdf(diffs[i]) for i in range(k)])
    hmm_score = 1-(1/likelihood)
    # scipy.stats.stats is a deprecated namespace; call scipy.stats.pearsonr directly.
    correlation_score = ss.pearsonr(meth['mean'],diffs)[0]
    correlation_diff = None
    if context in context_dict['A']:
        unmeth = context_dict['A'][context]
        correlation_diff = correlation_score - ss.pearsonr(unmeth['mean'],diffs)[0]
        frac_meth = meth['num']*1./unmeth['num']
    else:
        frac_meth = 1
    # Rises as the m6A:A training-count ratio for this context falls.
    representation_score = prob + 1 - frac_meth
    return {
        'hmm_score': hmm_score,
        'correlation_score': correlation_score,
        'correlation_diff': correlation_diff,
        'representation_score': representation_score,
    }
def base_models(base, twobase=False):
    """Map two-base strings (methylated base + following base) to model names.

    With ``base == 'A'`` and a truthy ``twobase``, G-following contexts map to
    the ``'MG'`` model and the rest to ``'MH'``. Otherwise every key maps to
    the single ``'general'`` model. The general table has keys starting with
    ``'M'``, ``'A'`` and ``'T'``; the two-base table has ``'M'`` and ``'A'`` keys.
    """
    if base == 'A' and twobase:
        # TODO: fix error where sites not methylated
        return {
            'MG': 'MG', 'MC': 'MH', 'MA': 'MH', 'MT': 'MH', 'MM': 'MH', 'MH': 'MH',
            'AT': 'MH', 'AC': 'MH', 'AG': 'MG', 'AA': 'MH', 'AM': 'MH',
        }
    # The leading-'T' keys were marked "shouldn't be necessary" by the author.
    return {first + second: 'general' for first in 'MAT' for second in 'ACGTM'}
#determine difference between measurements and model for bases surrounding methylated positions
#@profile
def extract_features(tsv_input,fasta_input,read2qual,k,skip_thresh,qual_thresh,modelfile,classifier,startline,endline=None,train=False,pos_label=None,base=None,motif=None,positions_list=None):
    """Collect current differences around methylated reference positions from one slice of a tsv.

    Reads the event-alignment tsv ``tsv_input`` starting at character offset
    ``max(startline-500, 0)`` while the running offset is ``<= endline-500``.
    The reference is marked with ``'M'`` by ``find_and_methylate`` (from
    ``motif`` or ``positions_list``). For each marked position seen in a read,
    it gathers ``event_current - model_current`` for the k-mers whose first
    ``'M'`` falls at each of the ``k`` offsets. It then builds a feature vector:
    the per-offset mean differences (0 where an offset had no events; order
    reversed for forward-strand observations), followed by the read quality.

    Lines are skipped when the read quality is below ``qual_thresh`` or
    ``ref_kmer`` is ``'NNNNNN'``. Observations with more than ``skip_thresh``
    empty offsets are dropped.

    In prediction mode (``train`` false) each observation is classified with
    the pickled model(s) from ``modelfile``. The label is ``'m6A'`` (or
    ``'m'+base``) when ``predict_proba(...)[0][1] >= 0.5``, otherwise
    ``base``, followed by a tab and that probability rounded to 2 decimals.
    In training mode the label is ``pos_label[(chrom, mpos, strand)]`` and
    feature vectors / contexts are also accumulated per model name.

    Rows ``[chrom, read, mpos, context, comma-joined diffs, strand, label]``
    are appended with ``writefi`` every 5000 observations and once at the end.
    The output file is the ``tsv_input`` path minus its last extension plus
    ``.diffs.{k}.tmp{startline}``, or ``.diffs.{k}.train.tmp{startline}`` in
    training mode.

    Args:
        tsv_input: path of the event-alignment tsv (format example in the
            comment above the ``with`` block).
        fasta_input: reference FASTA path.
        read2qual: mapping read name -> read quality. It is looked up by the
            full name, falling back to the text before the first ':' and then
            before the first '_'.
        k: k-mer length; also the number of offsets per observation.
        skip_thresh: maximum number of empty offsets allowed per observation.
        qual_thresh: minimum read quality.
        modelfile: pickled model path (prediction mode only). It holds either
            a single model (wrapped as ``{'general': model}``) or a dict of
            models keyed by the names returned by ``base_models``.
        classifier: not used in this function.
        startline, endline: character offsets delimiting this worker's slice.
            NOTE(review): ``endline`` defaults to None but is used in
            ``endline-500``, so callers must always pass it.
        train: training-mode switch (see above).
        pos_label: training labels keyed by ``(contig, position, strand)``.
        base: methylated base (e.g. 'A').
        motif, positions_list: forwarded to ``find_and_methylate``.

    Returns:
        ``(signals, contexts)`` in training mode: dicts mapping
        ``model name -> label -> list`` of feature vectors / context strings.
        Returns ``None`` in prediction mode.
    """
    #set position variables
    last_read,last_pos,last_pos_in_kmer,last_read_num = '',0,k,0
    last_contig = None
    #set count variables
    num_observations,w_skips,skipped_skips,pos_set,multi_meth_pos_set,read_set = 0,set(),set(),set(),set(),set()
    #set tracking variables for observation
    mpos = None  # reference position of the methylated base being collected, or None
    diff_col = [[] for xi in range(k)]  # diff_col[j]: differences from k-mers whose first 'M' is at offset j
    if not train:
        tsv_output = '.'.join(tsv_input.split('.')[:-1])+'.diffs.'+str(k)+'.tmp'+str(startline)
        # NOTE(review): pickle.load can execute arbitrary code - only load trusted model files.
        # encoding='latin' presumably lets models pickled under Python 2 load - confirm.
        modfi = open(modelfile,'rb')
        model = pickle.load(modfi,encoding='latin')
        modfi.close()
        if type(model) != dict:
            model = {'general':model} #for compatibility with previously trained model
            twobase = False
        else:
            twobase = True
        base_model = base_models(base,twobase)
    else:
        base_model = base_models(base,False) #or set to False?
        tsv_output = '.'.join(tsv_input.split('.')[:-1])+'.diffs.'+str(k)+'.train.tmp'+str(startline)
    # per model name: label -> list of feature vectors / context strings (filled in training mode)
    signals,contexts = {bm:{} for bm in base_model.values()},{bm:{} for bm in base_model.values()}
    towrite = []  # output rows buffered until the next writefi call
    #save only one set of adjoining methylated positions at a time - once the set complete, write the positions to a file
    #tsv format: ecoli 805 CGCCAT cc1da58e-3db3-4a4b-93c2-c78e1dbe6aba:1D_000:template t 1 102.16 0.963 0.00175 CGCCAT 102.23 1.93 -0.03 101.973,100.037,102.403,101.758,104.338,102.618,101.973
    with open(tsv_input,'r') as tsv:
        # Start 500 characters before startline (clamped at 0); the first line read may be partial.
        # NOTE(review): text-mode seek() is only guaranteed for 0 or tell() values; an arbitrary
        # character offset assumes a single-byte encoding - confirm the tsv is ASCII.
        tsv.seek(max(startline-500,0))
        linepos = max(startline-500,0)
        #startline, endline, and linepos are in characters -- previously used tsv.tell(), but incompatible with python3
        # NOTE(review): if endline-500 lies beyond the end of the file, readlines() keeps returning []
        # and this loop never terminates.
        while linepos <= endline-500:
            #print('current position',linepos)
            lines = tsv.readlines(8000000) #TODO: why 8M? reasonable size for memory consumption, but could change
            for line in lines:
                linepos += len(line)
                try:
                    chrom, read_pos, read_kmer, read_name, x, read_ind, event_current, event_sd, y, ref_kmer, model_current, ref_sd = line.split()[:12]
                except ValueError:
                    continue  # fewer than 12 columns
                # Re-methylate the reference when the contig changes. find_and_methylate returns
                # None for a contig absent from the FASTA, so the unpacking raises TypeError.
                if chrom != last_contig:
                    try:
                        meth_fwd,meth_rev = find_and_methylate(fasta_input,chrom,base,motif,positions_list)
                        last_contig = chrom
                    except TypeError: #ValueError
                        print('Error: could not find sequence for reference contig',chrom)
                        continue
                # NOTE(review): last_read is only updated in the 'M' branch below, so this is true for
                # every line of a new read until its first k-mer containing M.
                if read_name != last_read:
                    first_read_ind = int(read_ind)
                try:
                    qual = read2qual[read_name]
                except KeyError:
                    qual = read2qual[read_name.split(':')[0].split('_')[0]]
                if (qual < qual_thresh) or ref_kmer == 'NNNNNN':
                    continue
                # Forward strand if a new read's k-mer matches the reference k-mer, or if the read index
                # has increased within the read; otherwise reverse strand.
                if (read_name != last_read and read_kmer == ref_kmer) or (read_name == last_read and int(read_ind) > first_read_ind): #takes into account complementary palindromes - temporarily sets new reads to positive strand
                    rev = False
                    meth_ref = meth_fwd
                else:
                    rev = True
                    meth_ref = meth_rev
                read_pos = int(read_pos)
                reference_kmer = meth_ref[read_pos:read_pos+k]  # k bases of the methylated reference at this event
                #if finished context for previous potentially modified position, save and reset
                # NOTE(review): truthiness tests on mpos treat reference position 0 as "no position".
                if mpos and ((read_pos >= mpos+1 and read_name == last_read) or (read_name != last_read)):
                    #write to file
                    num_skips = len([x for x in diff_col if x == []])  # offsets with no events
                    if num_skips <= skip_thresh: #accept max number of skips within an observation
                        if num_skips > 0:
                            w_skips.add((last_read,mpos))
                        # mean difference per offset; 0 for offsets without events
                        diffs = [np.mean(kmer_pos) if kmer_pos!=[] else 0 for kmer_pos in diff_col]
                        if not last_rev:
                            diffs = diffs[::-1]
                        try:
                            last_qual = read2qual[last_read]
                        except KeyError:
                            last_qual = read2qual[last_read.split(':')[0].split('_')[0]]
                        diffs = diffs+[last_qual]  # feature vector: k mean differences + read quality
                        # 2k-1 bases centred on mpos, reverse-complemented for the reverse strand.
                        # NOTE(review): if mpos < k-1 the negative start makes this slice empty or short,
                        # and the indexing below can raise IndexError outside the try.
                        context = revcomp(last_ref[mpos-k+1:mpos+k],last_rev)
                        if context[int(len(context)/2)] == 'M':
                            try:
                                # model name for the two bases starting at the central M
                                twobase_model = base_model[context[int(len(context)/2):int(len(context)/2)+2]]
                                if not train:
                                    mod_prob = model[twobase_model].predict_proba([diffs]) #TODO: call model only when batch ready to write
                                    if mod_prob[0][1] >= 0.5:
                                        if base == 'A':
                                            label = 'm6A' #TODO: ensure correct direction + label unmeth/meth as appropriate
                                        else:
                                            label = 'm'+base
                                    else:
                                        label = base
                                    label = label+'\t'+str(np.round(mod_prob[0][1],2))
                                else:
                                    mod_prob = ''  # not used afterwards
                                    # NOTE(review): chrom is the current line's contig, which can differ from the
                                    # contig of the finished observation (last_ref) when a new read starts on
                                    # another contig; the same applies to the row appended below.
                                    label = pos_label[(chrom,mpos,strand(last_rev))]
                                    if label not in signals[twobase_model]:
                                        signals[twobase_model][label] = []
                                        contexts[twobase_model][label] = []
                                    signals[twobase_model][label].append(diffs)
                                    contexts[twobase_model][label].append(context)
                                towrite.append([chrom,last_read,str(mpos),context,','.join([str(diff) for diff in diffs]),strand(last_rev),label])
                                last_info = last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev)
                            except (IndexError,KeyError) as e:
                                # debugging dump, then stop (exit status 0)
                                print(last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev),'- Index or Key Error')
                                print(model.keys(), base_model.keys(), context[int(len(context)/2):int(len(context)/2)+2])
                                print(e)
                                print(model[twobase_model].predict_proba([diffs]))
                                sys.exit(0)
                        else:
                            # Central base is not M: dump state and stop (exit status 0).
                            # NOTE(review): mspacing is only bound after the window-sliding branch below has
                            # run, so this print can raise NameError.
                            print(last_read+'\t'+str(mpos)+'\t'+context+'\t'+','.join([str(diff) for diff in diffs])+'\t'+strand(last_rev))
                            print(read_name, rev, last_read, last_rev, last_first)
                            print(read_kmer,reference_kmer, ref_kmer, last_pos_in_kmer, mspacing, pos_in_kmer)
                            sys.exit(0)
                        num_observations += 1
                        if num_observations%5000 == 0:
                            writefi(towrite,tsv_output)
                            towrite = []
                        pos_set.add(mpos)
                        read_set.add(last_read)
                        if len(read_set)%1000 == 0 and len(read_set) > last_read_num:
                            #print len(read_set), 'reads examined'
                            last_read_num = len(read_set)
                    else:
                        skipped_skips.add((last_read,mpos))
                    #reset variables
                    # Start over unless this k-mer contains an M, in the same read, within skip_thresh+1
                    # of mpos; in that case slide the window to the next M and keep overlapping k-mers.
                    if len(reference_kmer.split('M')) < 2 or read_name != last_read or read_pos > mpos+skip_thresh+1: #allow no more than skip_thresh skips
                        diff_col = [[] for i in range(k)]
                        mpos = None
                        last_pos_in_kmer = k
                    else:
                        if reference_kmer[0] != 'M':
                            multi_meth_pos_set.add((last_read,mpos))
                        last_mpos = mpos
                        pos_in_kmer = len(reference_kmer.split('M')[0])  # offset of the first M in this k-mer
                        mpos = read_pos + pos_in_kmer
                        mspacing = min(k,mpos - last_mpos)
                        last_pos_in_kmer = pos_in_kmer
                        last_diff_col = diff_col
                        # Re-index relative to the new mpos: offset j from last_mpos becomes j+mspacing.
                        diffs = [[] for i in range(mspacing)] + diff_col[:-mspacing]
                        diff_col = diffs
                        # Length differs from k only when mspacing <= 0 (new mpos not after last_mpos).
                        if len(diff_col) != k:
                            try:
                                print(last_info,'- n diffs off')
                            except:
                                pass  # last_info is unbound until the first observation is written
                            #GGCGCM 613883 613878 False 2289b392-746e-4fa0-8226-d3ac661c9620_Basecall_2D_template 2289b392-746e-4fa0-8226-d3ac661c9620_Basecall_2D_template [[], [], [], [], [], [], []] 7
                            print(reference_kmer,last_mpos,mpos,mspacing,read_pos,read_pos-last_mpos,read_name,last_read,diff_col,mspacing, last_diff_col, last_diff_col[:-mspacing])
                            diff_col = [[] for i in range(k)]
                            sys.exit(0)
                #if modified base in reference, save surrounding context to call that position
                if 'M' in set(list(reference_kmer)):
                    pos_in_kmer = [i for i,x in enumerate(list(reference_kmer)) if x == 'M'][0]  # offset of the first M
                    #if new read, reset differences variable and proceed
                    if mpos:
                        if read_name != last_read:
                            mpos = None
                            diff_col = [[] for i in range(k)]
                        elif rev != last_rev:
                            mpos = None  # strand flipped within the read: start a new position (diff_col kept)
                    #if new read or new position
                    if not mpos:
                        mpos = read_pos+pos_in_kmer
                        last_pos_in_kmer = pos_in_kmer
                    last_read = read_name
                    last_rev = rev
                    last_first = first_read_ind
                    last_ref = meth_ref
                    diff_col[pos_in_kmer].append(np.round(float(event_current)-float(model_current),4))
                    last_pos = read_pos
                elif mpos:
                    # k-mer without M while a position is open: discard the partial observation
                    mpos = None
                    diff_col = [[] for i in range(k)]
    writefi(towrite,tsv_output)  # flush rows not yet written
    print('thread finished processing...:')
    print('%d observations' %num_observations)
    num_pos = len(pos_set)
    print('%d positions' %num_pos)
    print('%d regions with multiple methylated bases' %len(multi_meth_pos_set))
    print('%d observations with skips included' %len(w_skips))
    print('%d observations with too many skips' %len(skipped_skips))
    if train:
        return signals, contexts