diff --git a/caller/call_smn12.py b/caller/call_smn12.py index 68d69fb..c79dac0 100644 --- a/caller/call_smn12.py +++ b/caller/call_smn12.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -26,16 +26,19 @@ import sys from collections import namedtuple, Counter from scipy.stats import poisson -dir_name = os.path.join(os.path.dirname( - os.path.dirname(__file__)), 'depth_calling') + +dir_name = os.path.join(os.path.dirname(os.path.dirname(__file__)), "depth_calling") if os.path.exists(dir_name): sys.path.append(dir_name) -from depth_calling.copy_number_call import call_reg1_cn, process_raw_call_gc, \ - process_raw_call_denovo +from depth_calling.copy_number_call import ( + call_reg1_cn, + process_raw_call_gc, + process_raw_call_denovo, +) SMA_CUTOFF = 1e-6 TOTAL_NUM_SITES = 16 -SELECTED_SITES_INDEX = [a-1 for a in [7, 8, 10, 11, 12, 13, 14, 15]] +SELECTED_SITES_INDEX = [a - 1 for a in [7, 8, 10, 11, 12, 13, 14, 15]] SPLICE_INDEX = 12 POSTERIOR_CUTOFF_STRINGENT = 0.9 POSTERIOR_CUTOFF_MEDIUM = 0.8 @@ -49,19 +52,18 @@ def safe_division(x, y): def get_fraction(lsnp1, lsnp2): """Return the fraction of reads supporting SMN1.""" - return [safe_division(count1, count1 + lsnp2[i]) for i, count1 in - enumerate(lsnp1)] + return [safe_division(count1, count1 + lsnp2[i]) for i, count1 in enumerate(lsnp1)] def smn1_cn_zero(count_smn1, count_smn2, mdepth): """Return the likelihood ratio between SMN1 CN=0 and SMN1 CN=1.""" nsum = count_smn1 + count_smn2 - depthexpected0 = (ERROR_RATE/3) * float(nsum) + depthexpected0 = (ERROR_RATE / 3) * float(nsum) # haploid depth depthexpected1 = float(mdepth) / 2 prob_cp0 = poisson.pmf(count_smn1, depthexpected0) prob_cp1 = poisson.pmf(count_smn1, depthexpected1) - return prob_cp0/prob_cp1 + return prob_cp0 / prob_cp1 def get_raw_smn1_cn(full_length_cn, smn1_fraction): @@ -77,25 +79,35 @@ def get_raw_smn1_cn(full_length_cn, smn1_fraction): def update_full_length_cn(raw_cn_call): """Return the updated full-length SMN CN.""" # full-length CN can't be higher than total CN - cn_call = namedtuple( - 'cn_call', 'exon16_cn exon16_depth exon78_cn exon78_depth') + cn_call = namedtuple("cn_call", "exon16_cn exon16_depth exon78_cn exon78_depth") if raw_cn_call.exon78_depth >= raw_cn_call.exon16_depth: if raw_cn_call.exon78_cn is None and raw_cn_call.exon16_cn is not None: updated_cn_call = cn_call( - raw_cn_call.exon16_cn, raw_cn_call.exon16_depth, - raw_cn_call.exon16_cn, raw_cn_call.exon78_depth) + raw_cn_call.exon16_cn, + raw_cn_call.exon16_depth, + raw_cn_call.exon16_cn, + raw_cn_call.exon78_depth, + ) return updated_cn_call if raw_cn_call.exon78_cn is not None and raw_cn_call.exon16_cn is None: updated_cn_call = cn_call( - raw_cn_call.exon78_cn, raw_cn_call.exon16_depth, - raw_cn_call.exon78_cn, raw_cn_call.exon78_depth) + raw_cn_call.exon78_cn, + raw_cn_call.exon16_depth, + raw_cn_call.exon78_cn, + raw_cn_call.exon78_depth, + ) return updated_cn_call - if raw_cn_call.exon78_cn is not None and \ - raw_cn_call.exon16_cn is not None \ - and raw_cn_call.exon16_cn < raw_cn_call.exon78_cn: + if ( + raw_cn_call.exon78_cn is not None + and raw_cn_call.exon16_cn is not None + and raw_cn_call.exon16_cn < raw_cn_call.exon78_cn + ): updated_cn_call = cn_call( - raw_cn_call.exon16_cn, raw_cn_call.exon16_depth, - raw_cn_call.exon16_cn, raw_cn_call.exon78_depth) + raw_cn_call.exon16_cn, + raw_cn_call.exon16_depth, + raw_cn_call.exon16_cn, + raw_cn_call.exon78_depth, + ) return updated_cn_call return raw_cn_call @@ -104,58 +116,66 @@ def get_smn1_call_and_tag(cn_prob_all, combined_call): """Return the final SMN1 CN call and its tag value.""" cn_prob = [cn_prob_all[a] for a in SELECTED_SITES_INDEX] lsitecall_stringent = process_raw_call_gc( - cn_prob, POSTERIOR_CUTOFF_STRINGENT, keep_none=False) + cn_prob, POSTERIOR_CUTOFF_STRINGENT, keep_none=False + ) lsitecall_medium = process_raw_call_gc( - cn_prob, POSTERIOR_CUTOFF_MEDIUM, keep_none=False) + cn_prob, POSTERIOR_CUTOFF_MEDIUM, keep_none=False + ) lsitecall_loose = process_raw_call_gc( - cn_prob, POSTERIOR_CUTOFF_LOOSE, keep_none=False) + cn_prob, POSTERIOR_CUTOFF_LOOSE, keep_none=False + ) lsitecall_medium_counter = sorted( - Counter(lsitecall_medium).items(), key=lambda kv: kv[1], reverse=True) + Counter(lsitecall_medium).items(), key=lambda kv: kv[1], reverse=True + ) lsitecall_loose_counter = sorted( - Counter(lsitecall_loose).items(), key=lambda kv: kv[1], reverse=True) + Counter(lsitecall_loose).items(), key=lambda kv: kv[1], reverse=True + ) # sliding window of three sites covering the splice site # [11, 12, 13], [12, 13, 14], [13, 14, 15] if lsitecall_loose_counter[0][1] >= 5: for i in [11, 12, 13]: - sliding_window = [i, i+1, i+2] - prob_window = [cn_prob_all[a-1] for a in sliding_window] + sliding_window = [i, i + 1, i + 2] + prob_window = [cn_prob_all[a - 1] for a in sliding_window] call_window_loose = process_raw_call_gc( - prob_window, POSTERIOR_CUTOFF_LOOSE, keep_none=False) - if (len(call_window_loose) == 3 and - (min(call_window_loose) > lsitecall_loose_counter[0][0] or - max(call_window_loose) < lsitecall_loose_counter[0][0])): + prob_window, POSTERIOR_CUTOFF_LOOSE, keep_none=False + ) + if len(call_window_loose) == 3 and ( + min(call_window_loose) > lsitecall_loose_counter[0][0] + or max(call_window_loose) < lsitecall_loose_counter[0][0] + ): cn_smn1 = None - tag = 'SpliceDisagree' + tag = "SpliceDisagree" return tag, cn_smn1, lsitecall_loose # At least 5 sites need to agree if lsitecall_medium_counter[0][1] >= 5: - tag = 'PASS:Majority' + tag = "PASS:Majority" cn_smn1 = lsitecall_medium_counter[0][0] return tag, cn_smn1, lsitecall_loose # When the call summing up all sites is very confident - if (len(combined_call) == 1 and - combined_call[0] == lsitecall_loose_counter[0][0] and - lsitecall_loose_counter[0][1] >= 5): - tag = 'PASS:AgreeWithSum' + if ( + len(combined_call) == 1 + and combined_call[0] == lsitecall_loose_counter[0][0] + and lsitecall_loose_counter[0][1] >= 5 + ): + tag = "PASS:AgreeWithSum" cn_smn1 = lsitecall_medium_counter[0][0] return tag, cn_smn1, lsitecall_loose # The remaining ones will be no-call cn_smn1 = None - tag = 'Ambiguous' + tag = "Ambiguous" return tag, cn_smn1, lsitecall_loose -def get_sma_status( - site_calls, cn_prob, cn_smn1, sma_likelihood_ratio): +def get_sma_status(site_calls, cn_prob, cn_smn1, sma_likelihood_ratio): """Return the SMA status of the sample.""" - if (0 not in cn_prob[SPLICE_INDEX] and - (sma_likelihood_ratio < SMA_CUTOFF or - site_calls.count(0) <= 1)): + if 0 not in cn_prob[SPLICE_INDEX] and ( + sma_likelihood_ratio < SMA_CUTOFF or site_calls.count(0) <= 1 + ): if cn_smn1 == 0: return None else: @@ -169,8 +189,7 @@ def get_sma_status( def get_carrier_status(site_calls, cn_prob, cn_smn1, sma_likelihood_ratio): """Return the carrier status of the sample.""" - if 1 not in cn_prob[SPLICE_INDEX] and \ - len([a for a in site_calls if a <= 1]) <= 1: + if 1 not in cn_prob[SPLICE_INDEX] and len([a for a in site_calls if a <= 1]) <= 1: if cn_smn1 == 1: return None else: @@ -188,36 +207,64 @@ def get_smn12_call(raw_cn_call, lsnp1, lsnp2, var_ref, var_alt, mdepth): """Return the copy nubmer call of SMN1, SMN2 and SMNstar.""" smn1_fraction = get_fraction(lsnp1, lsnp2) smn_call = namedtuple( - 'smn_call', - 'SMN1 SMN2 SMN2delta78 isCarrier isSMA \ - SMN1_CN_raw Info Confidence g27134TG_raw g27134TG_CN') + "smn_call", + "SMN1 SMN2 SMN2delta78 isCarrier isSMA \ + SMN1_CN_raw Info Confidence g27134TG_raw g27134TG_CN", + ) raw_cn_call = update_full_length_cn(raw_cn_call) full_length_cn = raw_cn_call.exon78_cn if full_length_cn is None: # No-call for full-length CN - tag = 'FLCNnoCall' + tag = "FLCNnoCall" full_length_cn = raw_cn_call.exon78_depth raw_smn1_cn = get_raw_smn1_cn(full_length_cn, smn1_fraction) # In cases where full length copy number is no-call, # Test for zero copy of SMN1 at the splice variant site. # If true, report range for SMN2 CN sma_likelihood_ratio = smn1_cn_zero( - lsnp1[SPLICE_INDEX], lsnp2[SPLICE_INDEX], mdepth) + lsnp1[SPLICE_INDEX], lsnp2[SPLICE_INDEX], mdepth + ) if sma_likelihood_ratio > 1 / SMA_CUTOFF: - cn_smn2 = '%i-%i' % ( - math.floor(full_length_cn), math.ceil(full_length_cn)) + cn_smn2 = "%i-%i" % (math.floor(full_length_cn), math.ceil(full_length_cn)) dout = smn_call( - 0, cn_smn2, None, False, True, raw_smn1_cn, tag, - [None]*TOTAL_NUM_SITES, None, None) + 0, + cn_smn2, + None, + False, + True, + raw_smn1_cn, + tag, + [None] * TOTAL_NUM_SITES, + None, + None, + ) elif sma_likelihood_ratio < SMA_CUTOFF: dout = smn_call( - None, None, None, None, False, raw_smn1_cn, tag, - [None]*TOTAL_NUM_SITES, None, None) + None, + None, + None, + None, + False, + raw_smn1_cn, + tag, + [None] * TOTAL_NUM_SITES, + None, + None, + ) else: dout = smn_call( - None, None, None, None, None, raw_smn1_cn, tag, - [None]*TOTAL_NUM_SITES, None, None) + None, + None, + None, + None, + None, + raw_smn1_cn, + tag, + [None] * TOTAL_NUM_SITES, + None, + None, + ) else: full_length_cn = int(full_length_cn) @@ -230,17 +277,19 @@ def get_smn12_call(raw_cn_call, lsnp1, lsnp2, var_ref, var_alt, mdepth): cn_prob.append(call_reg1_cn(full_length_cn, lsnp1[i], lsnp2[i])) # Combine all 6 sites and make a call. combined_call = call_reg1_cn( - full_length_cn, sum([lsnp1[a] for a in SELECTED_SITES_INDEX]), - sum([lsnp2[a] for a in SELECTED_SITES_INDEX])) + full_length_cn, + sum([lsnp1[a] for a in SELECTED_SITES_INDEX]), + sum([lsnp2[a] for a in SELECTED_SITES_INDEX]), + ) - tag, cn_smn1, lsitecall_loose = get_smn1_call_and_tag( - cn_prob, combined_call) + tag, cn_smn1, lsitecall_loose = get_smn1_call_and_tag(cn_prob, combined_call) sma_likelihood_ratio = smn1_cn_zero( - lsnp1[SPLICE_INDEX], lsnp2[SPLICE_INDEX], mdepth) - is_sma = get_sma_status( - lsitecall_loose, cn_prob, cn_smn1, sma_likelihood_ratio) + lsnp1[SPLICE_INDEX], lsnp2[SPLICE_INDEX], mdepth + ) + is_sma = get_sma_status(lsitecall_loose, cn_prob, cn_smn1, sma_likelihood_ratio) is_carrier = get_carrier_status( - lsitecall_loose, cn_prob, cn_smn1, sma_likelihood_ratio) + lsitecall_loose, cn_prob, cn_smn1, sma_likelihood_ratio + ) # targeted variant(s) var_cn_confident = None @@ -249,12 +298,15 @@ def get_smn12_call(raw_cn_call, lsnp1, lsnp2, var_ref, var_alt, mdepth): raw_var_cn = get_raw_smn1_cn(full_length_cn, var_fraction)[0] var_cn = [call_reg1_cn(full_length_cn, var_alt[0], var_ref[0])] var_cn_filtered = process_raw_call_denovo( - var_cn, POSTERIOR_CUTOFF_MEDIUM, POSTERIOR_CUTOFF_LOOSE, - keep_none=False) + var_cn, POSTERIOR_CUTOFF_MEDIUM, POSTERIOR_CUTOFF_LOOSE, keep_none=False + ) if var_cn_filtered != []: var_cn_confident = var_cn_filtered[0] - if var_cn_confident is not None and cn_smn1 is not None \ - and cn_smn1 < var_cn_confident: + if ( + var_cn_confident is not None + and cn_smn1 is not None + and cn_smn1 < var_cn_confident + ): var_cn_confident = cn_smn1 # Call CN for SMN2 and SMN* @@ -263,13 +315,21 @@ def get_smn12_call(raw_cn_call, lsnp1, lsnp2, var_ref, var_alt, mdepth): if raw_cn_call.exon16_cn is not None: cn_smnstar = int(raw_cn_call.exon16_cn) - full_length_cn if cn_smnstar < 0: - raise Exception( - 'Total SMN CN is smaller than full-length SMN CN.') + raise Exception("Total SMN CN is smaller than full-length SMN CN.") if cn_smn1 is not None: cn_smn2 = full_length_cn - cn_smn1 dout = smn_call( - cn_smn1, cn_smn2, cn_smnstar, is_carrier, is_sma, raw_smn1_cn, tag, - cn_prob, raw_var_cn, var_cn_confident) + cn_smn1, + cn_smn2, + cn_smnstar, + is_carrier, + is_sma, + raw_smn1_cn, + tag, + cn_prob, + raw_var_cn, + var_cn_confident, + ) return dout diff --git a/caller/tests/test_call_smn12.py b/caller/tests/test_call_smn12.py index ccab3b9..52e8881 100644 --- a/caller/tests/test_call_smn12.py +++ b/caller/tests/test_call_smn12.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -28,21 +28,20 @@ from ..call_smn12 import update_full_length_cn, get_smn12_call, smn1_cn_zero SMA_CUTOFF = 1e-6 -cn_call = namedtuple( - 'cn_call', 'exon16_cn exon16_depth exon78_cn exon78_depth') +cn_call = namedtuple("cn_call", "exon16_cn exon16_depth exon78_cn exon78_depth") class TestCN(object): def test_smn1_cn_zero(self): likelihood_ratio = smn1_cn_zero(0, 30, 30) - assert likelihood_ratio > 1/SMA_CUTOFF + assert likelihood_ratio > 1 / SMA_CUTOFF likelihood_ratio = smn1_cn_zero(15, 15, 30) assert likelihood_ratio < SMA_CUTOFF likelihood_ratio = smn1_cn_zero(2, 32, 30) - assert likelihood_ratio < 1/SMA_CUTOFF + assert likelihood_ratio < 1 / SMA_CUTOFF assert likelihood_ratio > SMA_CUTOFF likelihood_ratio = smn1_cn_zero(1, 32, 30) - assert likelihood_ratio < 1/SMA_CUTOFF + assert likelihood_ratio < 1 / SMA_CUTOFF assert likelihood_ratio > SMA_CUTOFF def test_update_full_length_cn(self): @@ -82,7 +81,7 @@ def test_p5_nocall(self): assert final_call.SMN2delta78 is None assert final_call.isSMA is False assert final_call.isCarrier is True - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" # when p3 is no-call but p3 depth value is larger than p5, # take CN(p5) as CN(p3) @@ -96,7 +95,7 @@ def test_p3_nocall_larger_than_p5(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is True - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" # =========================================================================== # when p3 is no-call, but we can call SMA based on read count @@ -107,11 +106,11 @@ def test_p3_nocall_issma(self): lsnp2 = [0, 0, 0, 0, 0, 0, 0, 45, 45, 45, 45, 45, 45, 45, 45, 45] final_call = get_smn12_call(raw_cn_call, lsnp1, lsnp2, [60], [0], 30) assert final_call.SMN1 == 0 - assert final_call.SMN2 == '3-4' + assert final_call.SMN2 == "3-4" assert final_call.SMN2delta78 is None assert final_call.isSMA is True assert final_call.isCarrier is False - assert final_call.Info == 'FLCNnoCall' + assert final_call.Info == "FLCNnoCall" # when p3 is no-call, #reads ambiguous at splice site => isSMA is no-call def test_p3_nocall_issma_nocall(self): @@ -124,7 +123,7 @@ def test_p3_nocall_issma_nocall(self): assert final_call.SMN2delta78 is None assert final_call.isSMA is None assert final_call.isCarrier is None - assert final_call.Info == 'FLCNnoCall' + assert final_call.Info == "FLCNnoCall" # when p3 is no-call, not SMA def test_p3_nocall_not_sma(self): @@ -138,7 +137,7 @@ def test_p3_nocall_not_sma(self): # we can tell this sample is not SMA based on read count at splice site assert final_call.isSMA is False assert final_call.isCarrier is None - assert final_call.Info == 'FLCNnoCall' + assert final_call.Info == "FLCNnoCall" # =========================================================================== def test_carrier(self): @@ -151,7 +150,7 @@ def test_carrier(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is True - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" def test_sma(self): raw_cn_call = cn_call(4, 3.98, 3, 3.1) @@ -163,7 +162,7 @@ def test_sma(self): assert final_call.SMN2delta78 == 1 assert final_call.isSMA is True assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" def test_sma2_noise(self): raw_cn_call = cn_call(4, 3.98, 3, 3.1) @@ -175,7 +174,7 @@ def test_sma2_noise(self): assert final_call.SMN2delta78 == 1 assert final_call.isSMA is True assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" def test_not_carrier(self): raw_cn_call = cn_call(4, 3.98, 4, 4.02) @@ -187,7 +186,7 @@ def test_not_carrier(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" def test_not_carrier_majority_rule_agreewithsum(self): raw_cn_call = cn_call(4, 3.98, 4, 4.02) @@ -199,7 +198,7 @@ def test_not_carrier_majority_rule_agreewithsum(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is False - assert final_call.Info == 'PASS:AgreeWithSum' + assert final_call.Info == "PASS:AgreeWithSum" # check that the sites surrounding splice variant site are consistent with # overall call @@ -213,7 +212,7 @@ def test_not_carrier_majority_rule_splicedisagree(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is None - assert final_call.Info == 'SpliceDisagree' + assert final_call.Info == "SpliceDisagree" def test_smn1_nocall_ambiguous(self): raw_cn_call = cn_call(4, 3.98, 4, 4.02) @@ -225,7 +224,7 @@ def test_smn1_nocall_ambiguous(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is None - assert final_call.Info == 'Ambiguous' + assert final_call.Info == "Ambiguous" def test_smn1_nocall_ambiguous_isSMA_false(self): raw_cn_call = cn_call(4, 3.98, 4, 4.02) @@ -237,20 +236,20 @@ def test_smn1_nocall_ambiguous_isSMA_false(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is None - assert final_call.Info == 'Ambiguous' + assert final_call.Info == "Ambiguous" # =========================================================================== def test_0smn1_0smn2(self): raw_cn_call = cn_call(0, 0, 0, 0) - lsnp1 = [0]*16 - lsnp2 = [0]*16 + lsnp1 = [0] * 16 + lsnp2 = [0] * 16 final_call = get_smn12_call(raw_cn_call, lsnp1, lsnp2, [60], [0], 30) assert final_call.SMN1 == 0 assert final_call.SMN2 == 0 assert final_call.SMN2delta78 == 0 assert final_call.isSMA is True assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" # =========================================================================== def test_target_variant(self): @@ -263,8 +262,8 @@ def test_target_variant(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' - assert final_call.g27134TG_raw == round(4*16/(42+16), 2) + assert final_call.Info == "PASS:Majority" + assert final_call.g27134TG_raw == round(4 * 16 / (42 + 16), 2) assert final_call.g27134TG_CN == 1 def test_target_variant_larger_than_smn1cn(self): @@ -277,8 +276,8 @@ def test_target_variant_larger_than_smn1cn(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' - assert final_call.g27134TG_raw == round(4*42/(42+16), 2) + assert final_call.Info == "PASS:Majority" + assert final_call.g27134TG_raw == round(4 * 42 / (42 + 16), 2) # called 3, update to be equal to SMN1 CN assert final_call.g27134TG_CN == 2 @@ -296,7 +295,7 @@ def test_corner1(self): assert final_call.SMN2delta78 == 1 assert final_call.isSMA is None assert final_call.isCarrier is None - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" # all sites look like 1 and splice site looks like 0 def test_corner2(self): @@ -309,7 +308,7 @@ def test_corner2(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is None assert final_call.isCarrier is None - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" # a mixture of 0s and 1s so isSMA is no-call def test_corner3(self): @@ -322,7 +321,7 @@ def test_corner3(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is None assert final_call.isCarrier is None - assert final_call.Info == 'Ambiguous' + assert final_call.Info == "Ambiguous" # a mixture of 0s and 1s, but enough reads at splice site to say isSMA is # false @@ -336,7 +335,7 @@ def test_corner4(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is False assert final_call.isCarrier is None - assert final_call.Info == 'Ambiguous' + assert final_call.Info == "Ambiguous" # all sites look like 2 and splice site looks like 0 def test_corner5(self): @@ -349,4 +348,4 @@ def test_corner5(self): assert final_call.SMN2delta78 == 0 assert final_call.isSMA is None assert final_call.isCarrier is False - assert final_call.Info == 'PASS:Majority' + assert final_call.Info == "PASS:Majority" diff --git a/depth_calling/bin_count.py b/depth_calling/bin_count.py index 4b4f2ff..8405222 100644 --- a/depth_calling/bin_count.py +++ b/depth_calling/bin_count.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -32,34 +32,41 @@ MAD_CONSTANT = 1.4826 -def get_normed_depth(bamf, region_dic, nCores=1, reference=None, - gc_correct=True): +def get_normed_depth(bamf, region_dic, nCores=1, reference=None, gc_correct=True): """ Return the normalized depth values and coverage stats for a sample given a bam file """ - counts_for_normalization, gc_for_normalization, region_type_cn, \ - read_length = count_reads_and_prepare_for_normalization( - bamf, region_dic, nCores, reference) + counts_for_normalization, gc_for_normalization, region_type_cn, read_length = count_reads_and_prepare_for_normalization( + bamf, region_dic, nCores, reference + ) normed_depth = normalize( - counts_for_normalization, gc_for_normalization, region_type_cn, - read_length, gc_correct) + counts_for_normalization, + gc_for_normalization, + region_type_cn, + read_length, + gc_correct, + ) return normed_depth -def get_normed_depth_from_count(count_file, region_dic, read_length, - gc_correct=True): +def get_normed_depth_from_count(count_file, region_dic, read_length, gc_correct=True): """ Return the normalized depth values and coverage stats for a sample from a count file. """ count_dic = get_count_from_file(count_file) - counts_for_normalization, gc_for_normalization, region_type_cn = \ - process_counts_and_prepare_for_normalization(count_dic, region_dic) + counts_for_normalization, gc_for_normalization, region_type_cn = process_counts_and_prepare_for_normalization( + count_dic, region_dic + ) normed_depth = normalize( - counts_for_normalization, gc_for_normalization, region_type_cn, - read_length, gc_correct) + counts_for_normalization, + gc_for_normalization, + region_type_cn, + read_length, + gc_correct, + ) return normed_depth @@ -67,7 +74,7 @@ def get_normed_depth_from_count(count_file, region_dic, read_length, def median_correction(counts): """Return values corrected by median.""" y_counts = np.array(counts) - y_counts = y_counts/np.median(y_counts) + y_counts = y_counts / np.median(y_counts) return y_counts @@ -75,14 +82,15 @@ def gc_correction(counts, gc, scale_coefficient=0.9): """Return values corrected by GC content.""" y_counts = np.array(counts) x_gc = np.array(gc) - y_counts = y_counts/np.median(y_counts) + y_counts = y_counts / np.median(y_counts) value_lowess = lowess(y_counts, x_gc, return_sorted=False) sample_median = np.median(y_counts) gc_corrected = [] for i in range(len(y_counts)): scale_factor = scale_coefficient * min(y_counts[i], 2) gc_corrected.append( - y_counts[i] + scale_factor * (sample_median - value_lowess[i])) + y_counts[i] + scale_factor * (sample_median - value_lowess[i]) + ) return gc_corrected @@ -96,10 +104,13 @@ def get_read_count(bamfile, region, mapq_cutoff=0): reads = bamfile.fetch(region[0], region[1], region[2]) nreads = 0 for read in reads: - if (read.mapq >= mapq_cutoff and - read.is_secondary == 0 and read.is_supplementary == 0 and - read.reference_start >= region[1] and - read.reference_start < region[2]): + if ( + read.mapq >= mapq_cutoff + and read.is_secondary == 0 + and read.is_supplementary == 0 + and read.reference_start >= region[1] + and read.reference_start < region[2] + ): nreads += 1 return nreads @@ -110,14 +121,18 @@ def mad(list_of_number): return MAD_CONSTANT * np.median([abs(a - med) for a in list_of_number]) -def normalize(counts_for_normalization, gc_for_normalization, region_type_cn, - read_length, gc_correct): +def normalize( + counts_for_normalization, + gc_for_normalization, + region_type_cn, + read_length, + gc_correct, +): """ Return the normalized depth values for a sample. Median normalization and/or GC normalization """ - gc_corrected_depth = gc_correction( - counts_for_normalization, gc_for_normalization) + gc_corrected_depth = gc_correction(counts_for_normalization, gc_for_normalization) if gc_correct is True: # GC normalization corrected_depth = gc_corrected_depth @@ -126,7 +141,7 @@ def normalize(counts_for_normalization, gc_for_normalization, region_type_cn, corrected_depth = median_correction(counts_for_normalization) vmedian = np.median(counts_for_normalization) * read_length - vmad = round(mad(gc_corrected_depth)/np.median(gc_corrected_depth), 3) + vmad = round(mad(gc_corrected_depth) / np.median(gc_corrected_depth), 3) # Also return median depth of this sample and the coverage MAD. norm_count = {} @@ -134,8 +149,8 @@ def normalize(counts_for_normalization, gc_for_normalization, region_type_cn, if vmedian == 0: norm_count.setdefault(region_type, None) else: - norm_count.setdefault(region_type, 2*hap_cn*corrected_depth[i]) - depth_value = namedtuple('depth_value', 'normalized mediandepth mad') + norm_count.setdefault(region_type, 2 * hap_cn * corrected_depth[i]) + depth_value = namedtuple("depth_value", "normalized mediandepth mad") normalized_bin = depth_value(norm_count, vmedian, vmad) return normalized_bin @@ -152,8 +167,9 @@ def get_read_length(reads, number_to_count=2000): return np.median(read_length) -def count_reads_and_prepare_for_normalization(bamf, region_dic, nCores=1, - reference=None): +def count_reads_and_prepare_for_normalization( + bamf, region_dic, nCores=1, reference=None +): """ Return the normalized depth values and coverage stats for a sample given a bam file @@ -164,32 +180,33 @@ def count_reads_and_prepare_for_normalization(bamf, region_dic, nCores=1, gc_for_normalization = [] region_type_cn = OrderedDict() for region_type in region_dic: - if region_type != 'norm': + if region_type != "norm": lcount = [] region_length = None hap_cn = None for (region, gc) in region_dic[region_type]: region_reads = get_read_count(bamfile, region) lcount.append(region_reads) - if '_hapcn' in region[3]: + if "_hapcn" in region[3]: region_length = region[2] - region[1] gc_for_normalization.append(float(gc)) - hap_cn = int(region[3].split('hapcn')[1]) + hap_cn = int(region[3].split("hapcn")[1]) if region_length is None or hap_cn is None: - raise Exception( - 'Problem with region definition. Length not specified.') + raise Exception("Problem with region definition. Length not specified.") count_sum = sum(lcount) - counts_for_normalization.append(count_sum/(hap_cn*region_length)) + counts_for_normalization.append(count_sum / (hap_cn * region_length)) region_type_cn.setdefault(region_type, hap_cn) # Get read length from the last region reads = bamfile.fetch(region[0], region[1], region[2]) read_length = get_read_length(reads) - lregion = [(region[0], region[1], region[2], gc) - for (region, gc) in region_dic['norm']] + lregion = [ + (region[0], region[1], region[2], gc) for (region, gc) in region_dic["norm"] + ] get_normed_depth_bam = partial( - get_normalization_region_values, bam=bamf, reference=reference) + get_normalization_region_values, bam=bamf, reference=reference + ) region_groups = partition(lregion, nCores) pool = mp.Pool(nCores) result = pool.map(get_normed_depth_bam, region_groups) @@ -200,8 +217,7 @@ def count_reads_and_prepare_for_normalization(bamf, region_dic, nCores=1, bamfile.close() - return counts_for_normalization, gc_for_normalization, region_type_cn, \ - read_length + return counts_for_normalization, gc_for_normalization, region_type_cn, read_length def partition(lst, n): @@ -216,7 +232,7 @@ def get_normalization_region_values(l, bam, reference=None): for region in l: num_reads = get_read_count(bamfile, region) region_length = int(region[2]) - int(region[1]) - norm_depth = num_reads/region_length + norm_depth = num_reads / region_length region_gc = float(region[-1]) lcount.append((norm_depth, region_gc)) bamfile.close() @@ -242,26 +258,25 @@ def process_counts_and_prepare_for_normalization(count_dic, region_dic): gc_for_normalization = [] region_type_cn = OrderedDict() for region_type in region_dic: - if region_type != 'norm': + if region_type != "norm": lcount = [] region_length = None hap_cn = None for (region, gc) in region_dic[region_type]: lcount.append(count_dic[region[3]]) - if '_hapcn' in region[3]: + if "_hapcn" in region[3]: region_length = region[2] - region[1] gc_for_normalization.append(float(gc)) - hap_cn = float(region[3].split('hapcn')[1]) + hap_cn = float(region[3].split("hapcn")[1]) if region_length is None or hap_cn is None: - raise Exception( - 'Problem with region definition. Length not specified.') + raise Exception("Problem with region definition. Length not specified.") count_sum = sum(lcount) - counts_for_normalization.append(count_sum/(hap_cn*region_length)) + counts_for_normalization.append(count_sum / (hap_cn * region_length)) region_type_cn.setdefault(region_type, hap_cn) - for (region, gc) in region_dic['norm']: + for (region, gc) in region_dic["norm"]: region_length = int(region[2]) - int(region[1]) - counts_for_normalization.append(count_dic[region[3]]/region_length) + counts_for_normalization.append(count_dic[region[3]] / region_length) gc_for_normalization.append(gc) return counts_for_normalization, gc_for_normalization, region_type_cn diff --git a/depth_calling/copy_number_call.py b/depth_calling/copy_number_call.py index f1db681..badd82e 100644 --- a/depth_calling/copy_number_call.py +++ b/depth_calling/copy_number_call.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -40,9 +40,9 @@ def call_reg1_cn(full_cn, count_reg1, count_reg2, min_read=0): if nsum == 0: return [None] for i in range(full_cn + 1): - depthexpected = float(nsum) * float(i)/float(full_cn) + depthexpected = float(nsum) * float(i) / float(full_cn) if i == 0: - depthexpected = (ERROR_RATE/3) * float(nsum) + depthexpected = (ERROR_RATE / 3) * float(nsum) if i == full_cn: depthexpected = float(nsum) - ERROR_RATE * float(nsum) if count_reg1 <= count_reg2: @@ -53,19 +53,24 @@ def call_reg1_cn(full_cn, count_reg1, count_reg2, min_read=0): sum_prob = sum(prob) if sum_prob == 0: return [None] - post_prob = [float(a)/float(sum_prob) for a in prob] + post_prob = [float(a) / float(sum_prob) for a in prob] if count_reg2 < count_reg1: post_prob = post_prob[::-1] post_prob_sorted = sorted(post_prob, reverse=True) - if post_prob.index(post_prob_sorted[0]) != 0 and \ - count_reg1 <= min_read and count_reg2 >= min_read: + if ( + post_prob.index(post_prob_sorted[0]) != 0 + and count_reg1 <= min_read + and count_reg2 >= min_read + ): return [0] if post_prob_sorted[0] >= POSTERIOR_CUTOFF_STRINGENT: return [post_prob.index(post_prob_sorted[0])] # output the two most likely scenarios cn_prob_filtered = [ - post_prob.index(post_prob_sorted[0]), round(post_prob_sorted[0], 3), - post_prob.index(post_prob_sorted[1]), round(post_prob_sorted[1], 3) + post_prob.index(post_prob_sorted[0]), + round(post_prob_sorted[0], 3), + post_prob.index(post_prob_sorted[1]), + round(post_prob_sorted[1], 3), ] return cn_prob_filtered @@ -88,8 +93,9 @@ def process_raw_call_gc(cn_prob, post_cutoff, keep_none=True): return cn_prob_filtered -def process_raw_call_denovo(cn_prob, post_cutoff1, post_cutoff2, - list_total_cn=None, keep_none=True): +def process_raw_call_denovo( + cn_prob, post_cutoff1, post_cutoff2, list_total_cn=None, keep_none=True +): """ Filter raw CN calls based on posterior probablity cutoff. For de novel variant calling, i.e. non-gene-conversion cases. @@ -108,12 +114,11 @@ def process_raw_call_denovo(cn_prob, post_cutoff1, post_cutoff2, else: if list_total_cn is not None: total_cn = list_total_cn[i] - keep_var = ( - (cn_call[0] > 0 and cn_call[2] > 0) or - (cn_call[0] == total_cn or cn_call[2] == total_cn) + keep_var = (cn_call[0] > 0 and cn_call[2] > 0) or ( + cn_call[0] == total_cn or cn_call[2] == total_cn ) else: - keep_var = (cn_call[0] > 0 and cn_call[2] > 0) + keep_var = cn_call[0] > 0 and cn_call[2] > 0 if cn_call[1] > post_cutoff1: cn_prob_filtered.append(cn_call[0]) elif keep_var: diff --git a/depth_calling/gmm.py b/depth_calling/gmm.py index 24b871a..f653e20 100644 --- a/depth_calling/gmm.py +++ b/depth_calling/gmm.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -44,7 +44,7 @@ def __init__(self, num_state=DEFAULT_GMM_NSTATE): def set_gmm_par(self, dpar_tmp, svid): """Return the complete set of gmm parameters based on depth value.""" if svid not in dpar_tmp: - raise Exception('Variant id %s is not recognized.' % svid) + raise Exception("Variant id %s is not recognized." % svid) gmm_parameter = dpar_tmp[svid] # The gmm parameter file stores the adjustment factor (value_shift) for # all depth values. This value should be one if there is no bias @@ -53,32 +53,33 @@ def set_gmm_par(self, dpar_tmp, svid): # choice and aligner setting. If running the caller on a population, # it may be desirable to run GMM modeling for the population and make # this value more accurate. - self.value_shift = float(gmm_parameter['shift'][0]) + self.value_shift = float(gmm_parameter["shift"][0]) # The means are modeled as 0, 0.5, 1, 1.5, 2, 2.5, etc. # (We could also do 0, 1, 2, 3, 4, 5, etc.) # The gmm parameter file stores the mean depth values for CN=2 and CN=3 self.mu_state = [ - 0, 0.5, - float(gmm_parameter['mean'][0]), - float(gmm_parameter['mean'][1]) + 0, + 0.5, + float(gmm_parameter["mean"][0]), + float(gmm_parameter["mean"][1]), ] - mu_width = float(gmm_parameter['mean'][1]) - \ - float(gmm_parameter['mean'][0]) + mu_width = float(gmm_parameter["mean"][1]) - float(gmm_parameter["mean"][0]) for i in range(self.nstate): if i > 3: self.mu_state.append(1 + mu_width * (i - 2)) - sum_prior = sum([float(a) for a in gmm_parameter['prior']]) + sum_prior = sum([float(a) for a in gmm_parameter["prior"]]) if sum_prior >= 1: - raise Exception('Sum of priors is larger than 1.') + raise Exception("Sum of priors is larger than 1.") for i in range(self.nstate): # The gmm parameter file stores the prior frequencies for CN=0-6 - if i < len(gmm_parameter['prior']): - self.prior_state.append(float(gmm_parameter['prior'][i])) + if i < len(gmm_parameter["prior"]): + self.prior_state.append(float(gmm_parameter["prior"][i])) else: - prior_value = (1 - sum_prior)/(self.nstate - - len(gmm_parameter['prior'])) + prior_value = (1 - sum_prior) / ( + self.nstate - len(gmm_parameter["prior"]) + ) self.prior_state.append(prior_value) self.sigma_state = [SIGMA_CN0] @@ -86,28 +87,28 @@ def set_gmm_par(self, dpar_tmp, svid): # The other sd values are derived from it. for i in range(self.nstate): if i > 0: - sigma_value = float( - gmm_parameter['sd'][0]) * np.sqrt(float(i)/2) + sigma_value = float(gmm_parameter["sd"][0]) * np.sqrt(float(i) / 2) self.sigma_state.append(sigma_value) def gmm_call(self, val): """Return the final copy number call.""" - val_new = (val/2)/self.value_shift + val_new = (val / 2) / self.value_shift fcall = self.call_post_prob(val_new, POSTERIOR_CUTOFF) if fcall is not None: gauss_p_value = self.get_gauss_pmf_cdf( - val_new, self.mu_state[fcall], self.sigma_state[fcall])[1] + val_new, self.mu_state[fcall], self.sigma_state[fcall] + )[1] # apply another p-value cutoff # just comparing the depth value and the called CN if gauss_p_value < PV_CUTOFF: fcall = None - cn_call = namedtuple('cn_call', 'cn depth_value') - return cn_call(fcall, round(val/self.value_shift, 3)) + cn_call = namedtuple("cn_call", "cn depth_value") + return cn_call(fcall, round(val / self.value_shift, 3)) def get_gauss_pmf_cdf(self, test_value, gauss_mean, gauss_sd): """Return the pmf and cdf of a gaussian distribution.""" - test_stats = (test_value - gauss_mean)/gauss_sd - pdf = norm.pdf(test_stats)/gauss_sd + test_stats = (test_value - gauss_mean) / gauss_sd + pdf = norm.pdf(test_stats) / gauss_sd p_value = min(norm.cdf(test_stats), 1 - norm.cdf(test_stats)) return (pdf, p_value) @@ -117,10 +118,11 @@ def call_post_prob(self, val, post_cutoff): prob = [] for i in range(0, number_state): gauss_pmf = self.get_gauss_pmf_cdf( - val, self.mu_state[i], self.sigma_state[i])[0] + val, self.mu_state[i], self.sigma_state[i] + )[0] prob.append(gauss_pmf * self.prior_state[i]) sum_prob = float(sum(prob)) - post_prob = [float(a)/sum_prob for a in prob] + post_prob = [float(a) / sum_prob for a in prob] max_prob = max(post_prob) if max_prob >= post_cutoff: return post_prob.index(max_prob) diff --git a/depth_calling/snp_count.py b/depth_calling/snp_count.py index 11efc3a..17d2062 100644 --- a/depth_calling/snp_count.py +++ b/depth_calling/snp_count.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -25,19 +25,19 @@ from .utilities import open_alignment_file -COMPLEMENT = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'} +COMPLEMENT = {"A": "T", "T": "A", "C": "G", "G": "C", "N": "N"} SITES_STRINGENT = [] # consider being more stringent for exon8 site for SMN def reverse_complement(sequence): """Return the reverse complement of a sequence.""" - return ''.join(COMPLEMENT[b] for b in sequence[::-1]) + return "".join(COMPLEMENT[b] for b in sequence[::-1]) def get_nm(ltag): """Return the value of the NM tag.""" for tag in ltag: - if tag[0] == 'NM': + if tag[0] == "NM": return tag[1] return None @@ -51,41 +51,40 @@ def get_snp_position(pos_file): with open(pos_file) as read_pos: counter = -1 for line in read_pos: - if line[0] != '#' and line[0] != '\n': + if line[0] != "#" and line[0] != "\n": counter += 1 split_line = line.strip().split() - reg1_name = split_line[1] + '_' + str(counter) - reg2_name = split_line[3] + '_' + str(counter) + reg1_name = split_line[1] + "_" + str(counter) + reg2_name = split_line[3] + "_" + str(counter) reg1_base = split_line[2].upper() reg2_base = split_line[4].upper() - if split_line[-1] != '-': - dsnp1.setdefault(reg1_name, '_'.join( - [reg1_base, reg2_base])) - dsnp2.setdefault(reg2_name, '_'.join( - [reg1_base, reg2_base])) + if split_line[-1] != "-": + dsnp1.setdefault(reg1_name, "_".join([reg1_base, reg2_base])) + dsnp2.setdefault(reg2_name, "_".join([reg1_base, reg2_base])) else: dsnp1.setdefault( - reg1_name, - '_'.join([reg1_base, reverse_complement(reg2_base)]) + reg1_name, "_".join([reg1_base, reverse_complement(reg2_base)]) ) dsnp2.setdefault( - reg2_name, - '_'.join([reverse_complement(reg1_base), reg2_base]) + reg2_name, "_".join([reverse_complement(reg1_base), reg2_base]) ) dindex.setdefault(reg1_name, counter) dindex.setdefault(reg2_name, counter) nchr = split_line[0] - snp_lookup = namedtuple('snp_lookup', 'dsnp1 dsnp2 nchr dindex') + snp_lookup = namedtuple("snp_lookup", "dsnp1 dsnp2 nchr dindex") dbsnp = snp_lookup(dsnp1, dsnp2, nchr, dindex) return dbsnp def passing_read(pileupread): """Return whether a read passes filter.""" - return (not pileupread.is_del and not pileupread.is_refskip and - pileupread.alignment.is_secondary == 0 and - pileupread.alignment.is_supplementary == 0 and - pileupread.alignment.is_duplicate == 0) + return ( + not pileupread.is_del + and not pileupread.is_refskip + and pileupread.alignment.is_secondary == 0 + and pileupread.alignment.is_supplementary == 0 + and pileupread.alignment.is_duplicate == 0 + ) def passing_read_stringent(pileupread): @@ -93,9 +92,11 @@ def passing_read_stringent(pileupread): number_mismatch = get_nm(pileupread.alignment.tags) align_len = pileupread.alignment.query_alignment_length read_len = len(pileupread.alignment.query_sequence) - return (number_mismatch <= float(align_len) * 0.08 and - pileupread.query_position > 0 and - pileupread.query_position < read_len - 1) + return ( + number_mismatch <= float(align_len) * 0.08 + and pileupread.query_position > 0 + and pileupread.query_position < read_len - 1 + ) def get_reads_by_region(bamfile_handle, nchr, dsnp, dindex, min_mapq=0): @@ -105,23 +106,32 @@ def get_reads_by_region(bamfile_handle, nchr, dsnp, dindex, min_mapq=0): lsnp1 = [0] * len(dsnp) lsnp2 = [0] * len(dsnp) for snp_position_ori in dsnp: - snp_position = int(snp_position_ori.split('_')[0]) + snp_position = int(snp_position_ori.split("_")[0]) for pileupcolumn in bamfile_handle.pileup( - nchr, snp_position - 1, snp_position + 1, truncate=True, - stepper='nofilter', ignore_overlaps=False, ignore_orphan=False + nchr, + snp_position - 1, + snp_position + 1, + truncate=True, + stepper="nofilter", + ignore_overlaps=False, + ignore_orphan=False, ): site_position = pileupcolumn.pos + 1 if site_position == snp_position: - reg1_allele, reg2_allele = dsnp[snp_position_ori].split('_') + reg1_allele, reg2_allele = dsnp[snp_position_ori].split("_") for read in pileupcolumn.pileups: - if passing_read(read) and \ - read.alignment.mapping_quality >= min_mapq: + if ( + passing_read(read) + and read.alignment.mapping_quality >= min_mapq + ): dsnp_index = dindex[snp_position_ori] read_seq = read.alignment.query_sequence - if (site_position not in SITES_STRINGENT or - passing_read_stringent(read)): - reg1_allele_split = reg1_allele.split(',') - reg2_allele_split = reg2_allele.split(',') + if ( + site_position not in SITES_STRINGENT + or passing_read_stringent(read) + ): + reg1_allele_split = reg1_allele.split(",") + reg2_allele_split = reg2_allele.split(",") start_pos = read.query_position for allele in reg1_allele_split: end_pos = start_pos + len(allele) @@ -142,7 +152,7 @@ def get_fraction(lsnp1, lsnp2): if sumdepth == 0: reg1_fraction.append(0) else: - reg1_fraction.append(float(lsnp1[index])/float(sumdepth)) + reg1_fraction.append(float(lsnp1[index]) / float(sumdepth)) return reg1_fraction @@ -155,10 +165,8 @@ def get_supporting_reads(bamf, dsnp1, dsnp2, nchr, dindex, reference=None): assert len(dsnp1) == len(dsnp2) # Go through SNP sites in both regions, # and count the number of reads supporting each gene. - lsnp1_reg1, lsnp2_reg1 = get_reads_by_region( - bamfile_handle, nchr, dsnp1, dindex) - lsnp1_reg2, lsnp2_reg2 = get_reads_by_region( - bamfile_handle, nchr, dsnp2, dindex) + lsnp1_reg1, lsnp2_reg1 = get_reads_by_region(bamfile_handle, nchr, dsnp1, dindex) + lsnp1_reg2, lsnp2_reg2 = get_reads_by_region(bamfile_handle, nchr, dsnp2, dindex) lsnp1 = [sum(x) for x in zip(lsnp1_reg1, lsnp1_reg2)] lsnp2 = [sum(x) for x in zip(lsnp2_reg1, lsnp2_reg2)] bamfile_handle.close() diff --git a/depth_calling/tests/test_bin_count.py b/depth_calling/tests/test_bin_count.py index 03e787c..4dd1c47 100644 --- a/depth_calling/tests/test_bin_count.py +++ b/depth_calling/tests/test_bin_count.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -26,32 +26,37 @@ import pysam -from ..bin_count import get_read_count, get_read_length, mad, normalize, \ - get_normed_depth +from ..bin_count import ( + get_read_count, + get_read_length, + mad, + normalize, + get_normed_depth, +) from ..utilities import parse_region_file, open_alignment_file -test_data_dir = os.path.join(os.path.dirname(__file__), 'test_data') +test_data_dir = os.path.join(os.path.dirname(__file__), "test_data") class TestBinCount(object): def test_read_count_from_bam(self): - bam = os.path.join(test_data_dir, 'NA12878.bam') + bam = os.path.join(test_data_dir, "NA12878.bam") bamfile = open_alignment_file(bam) - region1 = ('5', 69372349, 69372400) + region1 = ("5", 69372349, 69372400) region1_count = get_read_count(bamfile, region1) assert region1_count == 11 - region2 = ('5', 70248246, 70248303) + region2 = ("5", 70248246, 70248303) region2_count = get_read_count(bamfile, region2) assert region2_count == 30 bamfile.close() def test_get_readlength(self): - bam = os.path.join(test_data_dir, 'NA12878.bam') + bam = os.path.join(test_data_dir, "NA12878.bam") bamfile = open_alignment_file(bam) - reads = bamfile.fetch('5', 69372349, 70248303) + reads = bamfile.fetch("5", 69372349, 70248303) read_length = get_read_length(reads) assert read_length == 150 bamfile.close() @@ -63,25 +68,27 @@ def test_mad(self): assert mad(list_number) == 2.2239 def test_normalize(self): - counts_for_normalization = [0.3, 0.25, 0.3, - 0.228, 0.29, 0.35, 0.31, 0.38, 0.42] - gc_for_normalization = [0.42, 0.42, 0.43, - 0.39, 0.4, 0.45, 0.43, 0.5, 0.6] - region_type_cn = {'exon16': 2, 'exon78': 2} + counts_for_normalization = [0.3, 0.25, 0.3, 0.228, 0.29, 0.35, 0.31, 0.38, 0.42] + gc_for_normalization = [0.42, 0.42, 0.43, 0.39, 0.4, 0.45, 0.43, 0.5, 0.6] + region_type_cn = {"exon16": 2, "exon78": 2} norm = normalize( - counts_for_normalization, gc_for_normalization, - region_type_cn, 150, gc_correct=False) - assert norm.normalized['exon16'] == 4 - assert round(norm.normalized['exon78'], 3) == 3.333 + counts_for_normalization, + gc_for_normalization, + region_type_cn, + 150, + gc_correct=False, + ) + assert norm.normalized["exon16"] == 4 + assert round(norm.normalized["exon78"], 3) == 3.333 assert norm.mediandepth == 45 assert round(norm.mad, 5) == 0.057 def test_bin_count(self): - bam = os.path.join(test_data_dir, 'NA12885.bam') - region_file = os.path.join(test_data_dir, 'SMN_region_37_short.bed') + bam = os.path.join(test_data_dir, "NA12885.bam") + region_file = os.path.join(test_data_dir, "SMN_region_37_short.bed") region_dic = parse_region_file(region_file) normed_depth = get_normed_depth(bam, region_dic, gc_correct=False) - assert round(normed_depth.normalized['exon16'], 3) == 3.876 - assert round(normed_depth.normalized['exon78'], 3) == 4.024 + assert round(normed_depth.normalized["exon16"], 3) == 3.876 + assert round(normed_depth.normalized["exon78"], 3) == 4.024 assert round(normed_depth.mediandepth, 2) == 48.75 assert round(normed_depth.mad, 5) == 0.066 diff --git a/depth_calling/tests/test_copy_number_call.py b/depth_calling/tests/test_copy_number_call.py index 84ca9dc..d69fa15 100644 --- a/depth_calling/tests/test_copy_number_call.py +++ b/depth_calling/tests/test_copy_number_call.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -24,8 +24,11 @@ import os import pytest -from ..copy_number_call import call_reg1_cn, process_raw_call_gc, \ - process_raw_call_denovo +from ..copy_number_call import ( + call_reg1_cn, + process_raw_call_gc, + process_raw_call_denovo, +) class TestCallCN(object): @@ -40,10 +43,7 @@ def test_call_reg1_cn(self): assert call == [2, 0.689, 1, 0.311] def test_process_raw_call_gc(self): - lcn = [ - [1], - [1, 0.8, 2, 0.2] - ] + lcn = [[1], [1, 0.8, 2, 0.2]] filtered_call = process_raw_call_gc(lcn, 0.7) assert filtered_call == [1, 1] filtered_call = process_raw_call_gc(lcn, 0.9) @@ -52,21 +52,13 @@ def test_process_raw_call_gc(self): assert filtered_call == [1] def test_process_raw_call_denovo(self): - lcn = [ - [1], - [0, 0.8, 1, 0.2], - [2, 0.6, 1, 0.4], - ] + lcn = [[1], [0, 0.8, 1, 0.2], [2, 0.6, 1, 0.4]] filtered_call = process_raw_call_denovo(lcn, 0.9, 0.7) assert filtered_call == [1, None, 1] filtered_call = process_raw_call_denovo(lcn, 0.9, 0.55) assert filtered_call == [1, None, 2] - lcn = [ - [1], - [1, 0.65, 0, 0.35], - [2, 0.6, 1, 0.4], - ] + lcn = [[1], [1, 0.65, 0, 0.35], [2, 0.6, 1, 0.4]] filtered_call = process_raw_call_denovo(lcn, 0.9, 0.62, [1, 1, 1]) assert filtered_call == [1, 1, 1] filtered_call = process_raw_call_denovo(lcn, 0.9, 0.7, [1, 1, 1]) diff --git a/depth_calling/tests/test_gmm.py b/depth_calling/tests/test_gmm.py index ae6cee5..02b6bef 100644 --- a/depth_calling/tests/test_gmm.py +++ b/depth_calling/tests/test_gmm.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -29,16 +29,16 @@ from ..gmm import Gmm from ..utilities import parse_gmm_file -test_data_dir = os.path.join(os.path.dirname(__file__), 'test_data') +test_data_dir = os.path.join(os.path.dirname(__file__), "test_data") class TestGMM(object): def test_gmm_parameter(self): - gmm_file = os.path.join(test_data_dir, 'SMN_gmm.txt') + gmm_file = os.path.join(test_data_dir, "SMN_gmm.txt") dpar_tmp = parse_gmm_file(gmm_file) test_gmm = Gmm() - test_gmm.set_gmm_par(dpar_tmp, 'exon1-6') + test_gmm.set_gmm_par(dpar_tmp, "exon1-6") assert test_gmm.value_shift == 0.994 assert len(test_gmm.mu_state) == 11 assert len(test_gmm.prior_state) == 11 @@ -51,13 +51,17 @@ def test_gmm_parameter(self): assert round(test_gmm.prior_state[0], 4) == 0.001 assert round(test_gmm.prior_state[8], 4) == 0.0003 assert test_gmm.sigma_state[0:4] == [ - 0.032, 0.051/np.sqrt(2), 0.051, 0.051*np.sqrt(1.5)] + 0.032, + 0.051 / np.sqrt(2), + 0.051, + 0.051 * np.sqrt(1.5), + ] def test_gmmcall(self): - gmm_file = os.path.join(test_data_dir, 'SMN_gmm.txt') + gmm_file = os.path.join(test_data_dir, "SMN_gmm.txt") dpar_tmp = parse_gmm_file(gmm_file) test_gmm = Gmm() - test_gmm.set_gmm_par(dpar_tmp, 'exon1-6') + test_gmm.set_gmm_par(dpar_tmp, "exon1-6") cncall = test_gmm.gmm_call(2.1) assert cncall[0] == 2 cncall = test_gmm.gmm_call(6.48) @@ -67,7 +71,7 @@ def test_gmmcall(self): assert cncall[0] is None test_gmm = Gmm() - test_gmm.set_gmm_par(dpar_tmp, 'exon7-8') + test_gmm.set_gmm_par(dpar_tmp, "exon7-8") cncall = test_gmm.gmm_call(0.18) assert cncall[0] == 0 cncall = test_gmm.gmm_call(0.95) diff --git a/depth_calling/tests/test_snp_count.py b/depth_calling/tests/test_snp_count.py index 12d53fc..2e5f199 100644 --- a/depth_calling/tests/test_snp_count.py +++ b/depth_calling/tests/test_snp_count.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -24,99 +24,91 @@ import os import pytest -from ..snp_count import get_snp_position, get_supporting_reads, \ - get_supporting_reads_single_region, get_fraction +from ..snp_count import ( + get_snp_position, + get_supporting_reads, + get_supporting_reads_single_region, + get_fraction, +) TOTAL_NUM_SITES = 16 -test_data_dir = os.path.join(os.path.dirname(__file__), 'test_data') +test_data_dir = os.path.join(os.path.dirname(__file__), "test_data") class TestParseSNPFile(object): def test_parse_snp_file(self): - snp_file = os.path.join(test_data_dir, 'SMN_SNP_37.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_37.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) assert len(dsnp1) == TOTAL_NUM_SITES assert len(dsnp2) == TOTAL_NUM_SITES - assert dsnp1['70247773_12'] == 'C_T' - assert dsnp2['69372353_12'] == 'C_T' - assert dsnp1['70247724_11'] == 'G_A' - assert dsnp2['69372304_11'] == 'G_A' - assert nchr == '5' + assert dsnp1["70247773_12"] == "C_T" + assert dsnp2["69372353_12"] == "C_T" + assert dsnp1["70247724_11"] == "G_A" + assert dsnp2["69372304_11"] == "G_A" + assert nchr == "5" - snp_file = os.path.join(test_data_dir, 'SMN_SNP_19.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_19.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) assert len(dsnp1) == TOTAL_NUM_SITES assert len(dsnp2) == TOTAL_NUM_SITES - assert dsnp1['70247773_12'] == 'C_T' - assert dsnp2['69372353_12'] == 'C_T' - assert dsnp1['70247724_11'] == 'G_A' - assert dsnp2['69372304_11'] == 'G_A' - assert nchr == 'chr5' + assert dsnp1["70247773_12"] == "C_T" + assert dsnp2["69372353_12"] == "C_T" + assert dsnp1["70247724_11"] == "G_A" + assert dsnp2["69372304_11"] == "G_A" + assert nchr == "chr5" - snp_file = os.path.join(test_data_dir, 'SMN_SNP_38.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_38.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) assert len(dsnp1) == TOTAL_NUM_SITES assert len(dsnp2) == TOTAL_NUM_SITES - assert dsnp1['70951946_12'] == 'C_T' - assert dsnp2['70076526_12'] == 'C_T' - assert dsnp1['70951463_10'] == 'T_C' - assert dsnp2['70076043_10'] == 'T_C' - assert nchr == 'chr5' + assert dsnp1["70951946_12"] == "C_T" + assert dsnp2["70076526_12"] == "C_T" + assert dsnp1["70951463_10"] == "T_C" + assert dsnp2["70076043_10"] == "T_C" + assert nchr == "chr5" # test indels and reverse complement - snp_file = os.path.join(test_data_dir, 'SMN_SNP_37_test.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_37_test.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) assert len(dsnp1) == TOTAL_NUM_SITES assert len(dsnp2) == TOTAL_NUM_SITES - assert dsnp1['70245876_1'] == 'T_G' - assert dsnp2['69370451_1'] == 'A_C' - assert dsnp1['70246016_2'] == 'G_T' - assert dsnp2['69370591_2'] == 'C_A' - assert dsnp1['70248108_15'] == 'CAC_CC' - assert dsnp2['69372688_15'] == 'CAC_CC' + assert dsnp1["70245876_1"] == "T_G" + assert dsnp2["69370451_1"] == "A_C" + assert dsnp1["70246016_2"] == "G_T" + assert dsnp2["69370591_2"] == "C_A" + assert dsnp1["70248108_15"] == "CAC_CC" + assert dsnp2["69372688_15"] == "CAC_CC" class TestReadCount(object): def test_get_snp_count(self): - snp_file = os.path.join(test_data_dir, 'SMN_SNP_37.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_37.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) - bam1 = os.path.join(test_data_dir, 'NA12878.bam') - lsnp1, lsnp2 = get_supporting_reads( - bam1, dsnp1, dsnp2, nchr, dindex) - assert lsnp1 == [ - 0, 0, 0, 0, 0, 0, 29, 35, 26, 39, 29, 35, 32, 37, 39, 39] - assert lsnp2 == [ - 0, 0, 0, 0, 0, 0, 12, 39, 39, 32, 26, 55, 45, 33, 42, 18] + bam1 = os.path.join(test_data_dir, "NA12878.bam") + lsnp1, lsnp2 = get_supporting_reads(bam1, dsnp1, dsnp2, nchr, dindex) + assert lsnp1 == [0, 0, 0, 0, 0, 0, 29, 35, 26, 39, 29, 35, 32, 37, 39, 39] + assert lsnp2 == [0, 0, 0, 0, 0, 0, 12, 39, 39, 32, 26, 55, 45, 33, 42, 18] - bam2 = os.path.join(test_data_dir, 'NA12885.bam') - lsnp1, lsnp2 = get_supporting_reads( - bam2, dsnp1, dsnp2, nchr, dindex) - assert lsnp1 == [ - 46, 32, 45, 36, 34, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 37] - assert lsnp2 == [ - 35, 35, 32, 29, 35, 59, 22, 28, 32, 24, 34, 32, 33, 28, 38, 21] + bam2 = os.path.join(test_data_dir, "NA12885.bam") + lsnp1, lsnp2 = get_supporting_reads(bam2, dsnp1, dsnp2, nchr, dindex) + assert lsnp1 == [46, 32, 45, 36, 34, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 37] + assert lsnp2 == [35, 35, 32, 29, 35, 59, 22, 28, 32, 24, 34, 32, 33, 28, 38, 21] - lsnp1, lsnp2 = get_supporting_reads_single_region( - bam2, dsnp1, nchr, dindex) - assert lsnp1 == [ - 46, 32, 45, 36, 26, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 34] - assert lsnp2 == [ - 0, 1, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + lsnp1, lsnp2 = get_supporting_reads_single_region(bam2, dsnp1, nchr, dindex) + assert lsnp1 == [46, 32, 45, 36, 26, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 34] + assert lsnp2 == [0, 1, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # test indels and reverse complement - snp_file = os.path.join(test_data_dir, 'SMN_SNP_37_test.txt') + snp_file = os.path.join(test_data_dir, "SMN_SNP_37_test.txt") dsnp1, dsnp2, nchr, dindex = get_snp_position(snp_file) - bam2 = os.path.join(test_data_dir, 'NA12885.bam') - lsnp1, lsnp2 = get_supporting_reads_single_region( - bam2, dsnp1, nchr, dindex) - assert lsnp1 == [ - 46, 32, 45, 36, 26, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 20] - assert lsnp2 == [ - 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16] + bam2 = os.path.join(test_data_dir, "NA12885.bam") + lsnp1, lsnp2 = get_supporting_reads_single_region(bam2, dsnp1, nchr, dindex) + assert lsnp1 == [46, 32, 45, 36, 26, 14, 36, 54, 38, 34, 41, 41, 40, 51, 40, 20] + assert lsnp2 == [0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16] def test_get_fraction(self): lsnp1 = [16, 15, 32, 25, 28, 0] lsnp2 = [40, 45, 31, 30, 27, 0] smn1_fraction = get_fraction(lsnp1, lsnp2) - assert smn1_fraction == [16/56, 15/60, 32/63, 25/55, 28/55, 0] + assert smn1_fraction == [16 / 56, 15 / 60, 32 / 63, 25 / 55, 28 / 55, 0] diff --git a/depth_calling/tests/test_utilities.py b/depth_calling/tests/test_utilities.py index 9bf0058..33df662 100644 --- a/depth_calling/tests/test_utilities.py +++ b/depth_calling/tests/test_utilities.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -26,13 +26,13 @@ from ..utilities import parse_region_file -test_data_dir = os.path.join(os.path.dirname(__file__), 'test_data') +test_data_dir = os.path.join(os.path.dirname(__file__), "test_data") class TestUtilities(object): def test_parse_reigon_file(self): - region_file = os.path.join(test_data_dir, 'SMN_region_19_short.bed') + region_file = os.path.join(test_data_dir, "SMN_region_19_short.bed") region_dic = parse_region_file(region_file) - assert len(region_dic['norm']) == 500 - assert len(region_dic['exon16']) == 2 - assert len(region_dic['exon78']) == 2 + assert len(region_dic["norm"]) == 500 + assert len(region_dic["exon16"]) == 2 + assert len(region_dic["exon78"]) == 2 diff --git a/depth_calling/utilities.py b/depth_calling/utilities.py index 161e289..61ce420 100644 --- a/depth_calling/utilities.py +++ b/depth_calling/utilities.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -29,8 +29,9 @@ def parse_region_file(region_file): region_dic = {} with open(region_file) as read_region: for line in read_region: - nchr, region_start, region_end, region_name, region_type, \ - region_gc = line.strip().split() + nchr, region_start, region_end, region_name, region_type, region_gc = ( + line.strip().split() + ) region_start = int(region_start) region_end = int(region_end) region = (nchr, region_start, region_end, region_name) @@ -45,12 +46,14 @@ def parse_gmm_file(gmm_file): for line in read_gmm: split_line = line.strip().split() dpar_tmp.setdefault(split_line[0], {}) - list_value = [a.split(':')[-1] for a in split_line[2:]] + list_value = [a.split(":")[-1] for a in split_line[2:]] dpar_tmp[split_line[0]].setdefault(split_line[1], list_value) return dpar_tmp def open_alignment_file(alignment_file, reference_fasta=None): - if alignment_file.endswith('cram'): - return pysam.AlignmentFile(alignment_file, 'rc', reference_filename=reference_fasta) - return pysam.AlignmentFile(alignment_file, 'rb') \ No newline at end of file + if alignment_file.endswith("cram"): + return pysam.AlignmentFile( + alignment_file, "rc", reference_filename=reference_fasta + ) + return pysam.AlignmentFile(alignment_file, "rb") diff --git a/smn_caller.py b/smn_caller.py index aba84fd..7489076 100755 --- a/smn_caller.py +++ b/smn_caller.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # SMNCopyNumberCaller -# Copyright 2019 Illumina, Inc. +# Copyright 2019-2020 Illumina, Inc. # All rights reserved. # # Author: Xiao Chen @@ -30,12 +30,18 @@ import pysam -from depth_calling.snp_count import get_supporting_reads, get_fraction, \ - get_snp_position +from depth_calling.snp_count import get_supporting_reads, get_fraction, get_snp_position from depth_calling.gmm import Gmm -from depth_calling.utilities import parse_gmm_file, parse_region_file, open_alignment_file -from depth_calling.bin_count import get_normed_depth, \ - get_normed_depth_from_count, get_read_length +from depth_calling.utilities import ( + parse_gmm_file, + parse_region_file, + open_alignment_file, +) +from depth_calling.bin_count import ( + get_normed_depth, + get_normed_depth_from_count, + get_read_length, +) from caller.call_smn12 import get_smn12_call MAD_THRESHOLD = 0.11 @@ -44,36 +50,52 @@ def load_parameters(): """Return parameters.""" parser = argparse.ArgumentParser( - description='Call Copy number of full-length SMN1, full-length SMN2 and \ - SMN* (Exon7-8 deletion) from a WGS bam file.') - parser.add_argument( - '--manifest', help='Manifest listing absolute paths to input BAM/CRAM files', - required=True) - parser.add_argument( - '--genome', help='Reference genome, select from 19, 37, or 38', - required=True) + description="Call Copy number of full-length SMN1, full-length SMN2 and \ + SMN* (Exon7-8 deletion) from a WGS bam file." + ) parser.add_argument( - '--outDir', help='Output directory', required=True) + "--manifest", + help="Manifest listing absolute paths to input BAM/CRAM files", + required=True, + ) parser.add_argument( - '--prefix', help='Prefix to output file', required=True) + "--genome", help="Reference genome, select from 19, 37, or 38", required=True + ) + parser.add_argument("--outDir", help="Output directory", required=True) + parser.add_argument("--prefix", help="Prefix to output file", required=True) parser.add_argument( - '--threads', help='Number of threads to use. Default is 1', type=int, default=1, - required=False) + "--threads", + help="Number of threads to use. Default is 1", + type=int, + default=1, + required=False, + ) parser.add_argument( - '--reference', help='Optional path to reference fasta file for CRAM', required=False) + "--reference", + help="Optional path to reference fasta file for CRAM", + required=False, + ) parser.add_argument( - '--countFilePath', help='Optional path to count files', required=False) + "--countFilePath", help="Optional path to count files", required=False + ) args = parser.parse_args() - if args.genome not in ['19', '37', '38']: - raise Exception('Genome not recognized. Select from 19, 37, or 38') + if args.genome not in ["19", "37", "38"]: + raise Exception("Genome not recognized. Select from 19, 37, or 38") return args def smn_cn_caller( - bam, region_dic, gmm_parameter, - snp_db, variant_db, threads, count_file=None, reference_fasta=None): + bam, + region_dic, + gmm_parameter, + snp_db, + variant_db, + threads, + count_file=None, + reference_fasta=None, +): """Return SMN CN calls for each sample.""" # 1. read counting, normalization if count_file is not None: @@ -82,56 +104,75 @@ def smn_cn_caller( read_length = get_read_length(reads) bamfile.close() normalized_depth = get_normed_depth_from_count( - count_file, region_dic, read_length, gc_correct=False) + count_file, region_dic, read_length, gc_correct=False + ) else: normalized_depth = get_normed_depth( - bam, region_dic, threads, reference=reference_fasta, gc_correct=False) + bam, region_dic, threads, reference=reference_fasta, gc_correct=False + ) # 2. GMM and CN call - cn_call = namedtuple( - 'cn_call', 'exon16_cn exon16_depth exon78_cn exon78_depth' - ) + cn_call = namedtuple("cn_call", "exon16_cn exon16_depth exon78_cn exon78_depth") gmm_exon16 = Gmm() - gmm_exon16.set_gmm_par(gmm_parameter, 'exon1-6') - gcall_exon16 = gmm_exon16.gmm_call(normalized_depth.normalized['exon16']) + gmm_exon16.set_gmm_par(gmm_parameter, "exon1-6") + gcall_exon16 = gmm_exon16.gmm_call(normalized_depth.normalized["exon16"]) gmm_exon78 = Gmm() - gmm_exon78.set_gmm_par(gmm_parameter, 'exon7-8') - gcall_exon78 = gmm_exon78.gmm_call(normalized_depth.normalized['exon78']) + gmm_exon78.set_gmm_par(gmm_parameter, "exon7-8") + gcall_exon78 = gmm_exon78.gmm_call(normalized_depth.normalized["exon78"]) raw_cn_call = cn_call( - gcall_exon16.cn, gcall_exon16.depth_value, - gcall_exon78.cn, gcall_exon78.depth_value + gcall_exon16.cn, + gcall_exon16.depth_value, + gcall_exon78.cn, + gcall_exon78.depth_value, ) # 3. Get SNP ratios smn1_read_count, smn2_read_count = get_supporting_reads( - bam, snp_db.dsnp1, snp_db.dsnp2, snp_db.nchr, snp_db.dindex, reference=reference_fasta + bam, + snp_db.dsnp1, + snp_db.dsnp2, + snp_db.nchr, + snp_db.dindex, + reference=reference_fasta, ) smn1_fraction = get_fraction(smn1_read_count, smn2_read_count) var_ref_count, var_alt_count = get_supporting_reads( - bam, variant_db.dsnp1, variant_db.dsnp2, variant_db.nchr, - variant_db.dindex, reference=reference_fasta + bam, + variant_db.dsnp1, + variant_db.dsnp2, + variant_db.nchr, + variant_db.dindex, + reference=reference_fasta, ) # 4. Call CN of SMN1 and SMN2 final_call = get_smn12_call( - raw_cn_call, smn1_read_count, smn2_read_count, - var_ref_count, var_alt_count, - normalized_depth.mediandepth + raw_cn_call, + smn1_read_count, + smn2_read_count, + var_ref_count, + var_alt_count, + normalized_depth.mediandepth, ) # 5. Prepare final call set sample_call = namedtuple( - 'sample_call', - 'Coverage_MAD Median_depth \ + "sample_call", + "Coverage_MAD Median_depth \ Full_length_CN_raw Total_CN_raw \ SMN1_read_support SMN2_read_support SMN1_fraction \ - g27134TG_REF_count g27134TG_ALT_count' + g27134TG_REF_count g27134TG_ALT_count", ) sample_cn_call = sample_call( - round(normalized_depth.mad, 3), round(normalized_depth.mediandepth, 2), - raw_cn_call.exon78_depth, raw_cn_call.exon16_depth, - smn1_read_count, smn2_read_count, [round(a, 2) for a in smn1_fraction], - var_ref_count, var_alt_count + round(normalized_depth.mad, 3), + round(normalized_depth.mediandepth, 2), + raw_cn_call.exon78_depth, + raw_cn_call.exon16_depth, + smn1_read_count, + smn2_read_count, + [round(a, 2) for a in smn1_fraction], + var_ref_count, + var_alt_count, ) doutput = sample_cn_call._asdict() @@ -142,23 +183,34 @@ def smn_cn_caller( def write_to_tsv(final_output, out_tsv): """Write to tsv output.""" header = [ - 'Sample', 'isSMA', 'isCarrier', 'SMN1_CN', 'SMN2_CN', 'SMN2delta7-8_CN', - 'Total_CN_raw', 'Full_length_CN_raw', 'g.27134T>G_CN', - 'SMN1_CN_raw' + "Sample", + "isSMA", + "isCarrier", + "SMN1_CN", + "SMN2_CN", + "SMN2delta7-8_CN", + "Total_CN_raw", + "Full_length_CN_raw", + "g.27134T>G_CN", + "SMN1_CN_raw", ] - with open(out_tsv, 'w') as tsv_output: - tsv_output.write('\t'.join(header)+'\n') + with open(out_tsv, "w") as tsv_output: + tsv_output.write("\t".join(header) + "\n") for sample_id in final_output: final_call = final_output[sample_id] output_per_sample = [ - sample_id, final_call['isSMA'], final_call['isCarrier'], - final_call['SMN1'], final_call['SMN2'], final_call['SMN2delta78'], - final_call['Total_CN_raw'], final_call['Full_length_CN_raw'], - final_call['g27134TG_CN'], - ','.join([str(a) for a in final_call['SMN1_CN_raw']]) + sample_id, + final_call["isSMA"], + final_call["isCarrier"], + final_call["SMN1"], + final_call["SMN2"], + final_call["SMN2delta78"], + final_call["Total_CN_raw"], + final_call["Full_length_CN_raw"], + final_call["g27134TG_CN"], + ",".join([str(a) for a in final_call["SMN1_CN_raw"]]), ] - tsv_output.write('\t'.join([str(a) - for a in output_per_sample]) + '\n') + tsv_output.write("\t".join([str(a) for a in output_per_sample]) + "\n") def main(): @@ -181,7 +233,7 @@ def main(): for required_file in [region_file, snp_file, variant_file, gmm_file]: if os.path.exists(required_file) == 0: - raise Exception('File %s not found.' % required_file) + raise Exception("File %s not found." % required_file) if os.path.exists(outdir) == 0: os.makedirs(outdir) @@ -190,8 +242,8 @@ def main(): variant_db = get_snp_position(variant_file) gmm_parameter = parse_gmm_file(gmm_file) region_dic = parse_region_file(region_file) - out_json = os.path.join(outdir, prefix + '.json') - out_tsv = os.path.join(outdir, prefix + '.tsv') + out_json = os.path.join(outdir, prefix + ".json") + out_tsv = os.path.join(outdir, prefix + ".tsv") final_output = {} with open(manifest) as read_manifest: for line in read_manifest: @@ -199,38 +251,46 @@ def main(): sample_id = os.path.splitext(os.path.basename(bam_name))[0] count_file = None if path_count_file is not None: - count_file = os.path.join( - path_count_file, sample_id + '_count.txt') + count_file = os.path.join(path_count_file, sample_id + "_count.txt") if count_file is None and os.path.exists(bam_name) == 0: logging.warning( - 'Input alignmet file for sample %s does not exist.', sample_id) + "Input alignmet file for sample %s does not exist.", sample_id + ) elif count_file is not None and os.path.exists(count_file) == 0: logging.warning( - 'Input count file for sample %s does not exist', sample_id) + "Input count file for sample %s does not exist", sample_id + ) else: logging.info( - 'Processing sample %s at %s', sample_id, - datetime.datetime.now() + "Processing sample %s at %s", sample_id, datetime.datetime.now() ) smn_call = smn_cn_caller( - bam_name, region_dic, gmm_parameter, - snp_db, variant_db, threads, count_file, reference_fasta + bam_name, + region_dic, + gmm_parameter, + snp_db, + variant_db, + threads, + count_file, + reference_fasta, ) # Use normalized coverage MAD across stable regions # as a sample QC measure. - if smn_call['Coverage_MAD'] > MAD_THRESHOLD: + if smn_call["Coverage_MAD"] > MAD_THRESHOLD: logging.warning( "Sample %s has uneven coverage. CN calls may be \ - unreliable.", sample_id) + unreliable.", + sample_id, + ) final_output.setdefault(sample_id, smn_call) # Write to json - logging.info('Writing to json at %s', datetime.datetime.now()) - with open(out_json, 'w') as json_output: + logging.info("Writing to json at %s", datetime.datetime.now()) + with open(out_json, "w") as json_output: json.dump(final_output, json_output) # Write to tsv - logging.info('Writing to tsv at %s', datetime.datetime.now()) + logging.info("Writing to tsv at %s", datetime.datetime.now()) write_to_tsv(final_output, out_tsv)