forked from cbg-ethz/cojac
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcooc-mutbamscan
executable file
·212 lines (180 loc) · 7.43 KB
/
cooc-mutbamscan
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import sys
import os
#import re
import argparse
import csv
import json
import yaml
import gzip
import pysam
# parse command line
argparser = argparse.ArgumentParser(
    fromfile_prefix_chars='@',  # support pass BAM list in file instead of parameters
    description="scan amplicon (covered by long read pairs) for mutation cooccurrence",
    epilog="@listfile can be used to pass a long list of parameters (e.g.: a large number of BAMs) in a file instead of command line")

# input: exactly one of a samples TSV list or an explicit list of alignment files
inputgroup = argparser.add_mutually_exclusive_group(required=True)
inputgroup.add_argument('-s', '--samples', metavar='TSV',
    type=str, dest='samples', help="V-pipe samples list tsv")
inputgroup.add_argument('-a', '--alignments', metavar='BAM/CRAM', nargs='+',
    dest='alignments', help="alignment files")

# where and what to look for
argparser.add_argument('-p', '--prefix', metavar='PATH', required=False, default="working/samples",
    type=str, dest='prefix', help="V-pipe work directory prefix for where to look at align files when using TSV samples list")
argparser.add_argument('-r', '--reference', metavar='REFID', required=False, default='NC_045512.2',
    type=str, dest='rq_chr', help="reference to look for in alignment files")

# outputs (all optional)
argparser.add_argument('-j', '--json', metavar='JSON', required=False, default=None,
    type=str, dest='json', help="output results to as JSON file")
argparser.add_argument('-y', '--yaml', metavar='YAML', required=False, default=None,
    type=str, dest='yaml', help="output results to as yaml file")
argparser.add_argument('-t', '--tsv', metavar='TSV', required=False, default=None,
    type=str, dest='tsv', help="output results to as (raw) tsv file")
argparser.add_argument('-d', '--dump',
    action="store_true", dest='dump', help = "dump the python object to the terminal")

args = argparser.parse_args()

# check that the input files exist before doing any work.
# argparser.error() is used instead of assert: asserts are stripped when python
# runs with -O, and error() prints the usage plus the message and exits with status 2
if args.samples is not None:
    if not os.path.isfile(args.samples):
        argparser.error(f"cannot find sample file {args.samples}")
else:
    for align in args.alignments:
        if not os.path.isfile(align):
            argparser.error(f"cannot find alignment file {align}")
def test_read(read, mut_dict):
"""
test if mutations listed in mut_dict are present in the pysam read
returns a list with:
found_site: list of site present in the read (no matter content)
found_mut: list of those position which have the mutations variant
"""
# WARNING pysam is 0-based! (see here: https://pysam.readthedocs.io/en/latest/faq.html#pysam-coordinates-are-wrong )
# 1. check which mutation' sites are in range of that read
found_site = [p for p,m in mut_dict.items() if read.reference_start <= (p-1) <= (read.reference_end-len(m))]
if not len(found_site):
return (None, None) # sites aren't present no point checking variants
# 2. of those sites, check which content mutations' variants
read_dict = dict(zip(read.get_reference_positions(full_length=True), read.query_sequence))
found_mut = []
for p in found_site:
l=len(mut_dict[p]) #lenght
# base present ?
srch=[(p-1+i) in read_dict for i in range(l)]
if all(srch): # all positions found!
# check if it's the expected mutation(s)
if all([read_dict[p-1+i] == mut_dict[p][i] for i in range(l)]):
found_mut.append(p)
elif not any(srch): # none position found! (entire deletion)
# check if we're hunting for a string of deletions (-)
if '-' * l == mut_dict[p]:
found_mut.append(p)
# TODO give some thoughs about partial deletions
if len(found_mut): # found mutation as sites
return (found_site, found_mut)
else: # sites present, but no mutation found
return (found_site, None)
# scan an amplicon for a specific set of mutations
def scanamplicon(read_iter, mut_dict):
    """
    tally, across the read pairs of an amplicon, how many mutation sites each
    pair covers and how many of them carry the variant

    read_iter: iterable of pysam reads; mates are grouped by query name
    mut_dict: mutations to look for, handed over to test_read as is

    returns a dict with two histograms:
        "sites": { number of distinct sites covered by a pair: count of pairs }
        "muts": { number of distinct mutations found in a pair: count of pairs }
    (pairs covering no site, resp. carrying no mutation, are not counted)
    """
    # group the mates under their common read name
    # TODO inefficient, could be streamed (with an accumulator for pairs)
    pairs = {}
    for aln in read_iter:
        mate = 'R1' if aln.is_read1 else 'R2'
        pairs.setdefault(str(aln.query_name), {})[mate] = aln
    print("amplion:", len(pairs))

    # tally the mutation sites and the presence of variant in there accross all reads
    # TODO merge with above
    sites_per_pair = []
    muts_per_pair = []
    for mates in pairs.values():
        covered = []
        mutated = []
        for aln in mates.values():
            sites, muts = test_read(aln, mut_dict)
            if sites is not None:
                covered += sites
            if muts is not None:
                mutated += muts
        if covered:
            sites_per_pair.append(covered)
        if mutated:
            muts_per_pair.append(mutated)

    # a site seen by both mates counts once, hence the set()
    sites_cnt = np.unique([len(set(s)) for s in sites_per_pair], return_counts=True)
    print("sites:", len(sites_per_pair), sites_cnt)
    muts_cnt = np.unique([len(set(m)) for m in muts_per_pair], return_counts=True)
    print("muts:", len(muts_per_pair), muts_cnt)

    # .tolist() turns the numpy scalars into plain python ints
    histograms = {"sites": {}, "muts": {}}
    if sites_per_pair:
        histograms["sites"] = dict(zip(sites_cnt[0].tolist(), sites_cnt[1].tolist()))
    if muts_per_pair:
        histograms["muts"] = dict(zip(muts_cnt[0].tolist(), muts_cnt[1].tolist()))
    return histograms
# TODO better naming
def findbam(prefix, batch, sample):
    '''
    locate the alignment file of a sample

    builds the path <prefix>/<sample>/<batch>/alignments/REF_aln.bam,
    asserts that it is an existing file, and returns that path
    '''
    bam_path = os.path.join(prefix, sample, batch, 'alignments/REF_aln.bam')
    is_there = os.path.isfile(bam_path)
    assert is_there, f"cannot find alginment file {bam_path}"
    return bam_path
def scanbam(alnfname, amplicons):
    '''
    scan a bamfile found at alnfname, one amplicon after the other

    alnfname: path of the alignment file, opened with pysam
    amplicons: dict of amplicon name -> (begin, end, mut_dict)

    reads are fetched on the reference named by the module-level rq_chr

    returns a dict of amplicon name -> result of scanamplicon;
    amplicons with an empty mut_dict are left out
    '''
    results = {}
    with pysam.AlignmentFile(alnfname, "rb") as bam:
        for name, (begin, end, muts) in amplicons.items():
            # we need at least 2 to compute co-occurrence
            # HACK 2: the threshold is actually 1, so amplicons with a
            # single mutation are scanned too
            if len(muts) == 0:
                continue
            print(f"amplicon_{name}", begin, end, muts, sep='\t', end='\t')
            region_reads = bam.fetch(rq_chr, begin, end)
            results[name] = scanamplicon(region_reads, muts)
    return results
# hardcode
# TODO compute from mutations
# amplicon name -> [begin, end, {position: variant}], in the layout that
# scanbam unpacks; a variant made of '-' is a deletion of that length
amplicons = {
    name: [begin, end, muts]
    for name, begin, end, muts in (
        ('72_UK', 21718, 21988, {21765: '------', 21991: '---'}),
        ('78_UK', 23502, 23810, {23604: 'A', 23709: 'T'}),
        ('92_UK', 27827, 28102, {27972: 'T', 28048: 'T', 28111: 'G'}),
        ('93_UK', 28147, 28414, {28111: 'G', 28280: 'CTA'}),
        ('76_SA', 22879, 23142, {23012: 'A', 23063: 'T'}),
        # ('76_UK', 22879, 23142, {23063: 'T', 23271: 'A'}),
        ('77_EU', 23194, 23464, {23403: 'G'}),
    )
}
# TODO autoguess reference from alignment
rq_chr=args.rq_chr # e.g.: 'NC_045512.2'

# results: sample -> amplicon name -> {"sites": ..., "muts": ...} (see scanbam)
table={}
if args.samples is not None:
    # loop for if samples are given through the TSV list
    with open(args.samples,'rt',encoding='utf-8', newline='') as tf: # this file has the same content as the original experiment
        for r in csv.reader(tf, delimiter='\t'): #dialect='excel-tab'):
            if not r:
                # csv.reader yields an empty list for a blank line (e.g. at the
                # end of the file); skip it instead of failing on the unpacking
                # below, which would abort the run before any output is written
                continue
            sample,batch=r[:2]
            print(sample)
            alnfname = findbam(args.prefix, batch, sample)
            table[sample]=scanbam(alnfname, amplicons)
else:
    # loop for if samples are given through -a option
    # this option can also be used to dispatch per sample jobs on the cluster
    for alnfname in args.alignments:
        sample = alnfname # HACK use the whole BAM file as sample name
        table[sample]=scanbam(alnfname, amplicons)

#
# dumps, for being able to take it from here
#

# the terminal dump is also the fallback when neither JSON nor YAML output is requested
if (args.dump) or (not (args.json or args.yaml)):
    print(table)
if args.json:
    with open(args.json, 'wt') as jf:
        json.dump(obj=table, fp=jf)
if args.yaml:
    with open(args.yaml, 'wt') as yf:
        print(yaml.dump(table, sort_keys=False), file=yf)

# raw data into pandas.DataFrame
if args.tsv:
    # one row per sample, one column per (amplicon, "sites"/"muts") pair
    raw_table_df=pd.DataFrame.from_dict(data={i: {(j,k): table[i][j][k] for j in table[i] for k in table[i][j]} for i in table},
        orient='index')
    #with pd.option_context('display.max_rows', None): #, 'display.max_columns', None):
        #print(raw_table_df)
    raw_table_df.to_csv(args.tsv, sep="\t", compression={'method':'infer'})