# quantify_umifm_from_alignments.py
# (forked from indrops/indrops)
import pysam
from collections import defaultdict
try:
import cPickle as pickle
except:
import pickle
from copy import copy
from itertools import combinations
from numpy import memmap
# from indrops import load_indexed_memmapped_array
def print_to_log(msg):
    """
    Write `msg` (converted with str()) followed by a newline to stderr.

    Wrapper to eventually log in a smarter way, instead of using 'print()'.
    """
    # Imported locally: at module level `sys` is only imported under the
    # `if __name__ == "__main__"` guard, so relying on the global name would
    # raise NameError whenever this module is imported instead of executed.
    import sys

    sys.stderr.write(str(msg) + '\n')
def quant(args):
    """
    Count UMI-filtered mapped reads (UMIFM) per gene for one barcode.

    Alignments are read from standard input (pysam.AlignmentFile("-", "r")).
    Consecutive alignments sharing a query name are treated as one read. Each
    read's alignments are filtered, surviving reads are grouped by UMI, every
    UMI is resolved to one gene or to a small group of indistinguishable
    genes, and the results are written to the file objects carried by `args`.

    Attributes read from `args`:
        m, u, d, polyA, min_non_polyA, split_ambi, mixed_ref  -- filter settings
        counts, ambigs, ambig_partners, metrics  -- writable text file objects
        barcode, library       -- identifiers written to the outputs / BAM tags
        min_counts             -- minimum total count for the barcode to be kept
        write_header           -- if true, write header rows first
        bam                    -- optional path of a BAM file to create
        soft_masked_regions    -- optional open file holding a pickled
                                  {reference_name: [(start, end), ...]} mapping

    Side effects: writes to (and closes) args.counts, args.ambigs and
    args.ambig_partners, writes to args.metrics, may append the barcode to
    '<counts file name>.ignored', may create the BAM file, and logs one summary
    line through print_to_log(). Returns None.
    """
    # Convert args to more explicit names
    multiple_alignment_threshold = args.m
    distance_from_tx_end = args.d
    split_ambiguities = args.split_ambi
    ambig_count_threshold = args.u
    using_mixed_ref = args.mixed_ref

    def tx_to_gid(tx):
        # Assume that references are named 'transcript_name|gene_name'
        return tx.split('|')[1]

    def umi_from_read_name(read_name):
        # The UMI position depends on how many ':'-separated fields the
        # read name has.
        fields = read_name.split(':')
        if len(fields) in (2, 3):
            return fields[1]  # Adrian format (old: 2 fields, current: 3)
        return fields[4]  # Allon format

    sam_input = pysam.AlignmentFile("-", "r")

    # Tuple containing lengths of reference sequences
    ref_lengths = copy(sam_input.lengths)

    # BAM file to be generated (optional)
    sam_output = None
    if args.bam:
        sam_output = pysam.AlignmentFile(args.bam, "wb", template=sam_input)

    # Load cache of low complexity regions
    soft_masked_regions = None
    soft_masked_fraction_threshold = 0.5
    if args.soft_masked_regions:
        # NOTE(security): pickle.load can execute arbitrary code; only pass
        # files from a trusted source here.
        low_complexity_regions = pickle.load(args.soft_masked_regions)
        soft_masked_regions = defaultdict(set)
        for tx, regions in low_complexity_regions.items():
            if regions:
                soft_masked_regions[tx] = set.union(*[set(range(a, b)) for a, b in regions])

    def process_read_alignments(alignments):
        """
        Filter the alignments of a single read and pick one per gene.

        `alignments` is the list of aligned segments sharing one query name.
        Returns (chosen_alignments, read_filter_status) where
        chosen_alignments maps gene id -> chosen alignment (empty when the
        read is rejected) and read_filter_status is the tuple
        (unique, rescued_non_unique, failed_m_threshold, dependent_on_polyA_tail).
        """
        # Remove any alignments that aren't supported by a certain number of non-polyA bases.
        dependent_on_polyA_tail = False
        if args.min_non_polyA > 0:
            polyA_independent_alignments = []
            for a in alignments:
                start_of_polyA = ref_lengths[a.reference_id] - args.polyA
                if a.reference_end < start_of_polyA:
                    # The alignment doesn't overlap the polyA tail.
                    polyA_independent_alignments.append(a)
                else:
                    non_polyA_part = start_of_polyA - a.reference_start
                    if non_polyA_part > args.min_non_polyA:
                        polyA_independent_alignments.append(a)
            dependent_on_polyA_tail = len(polyA_independent_alignments) == 0
            alignments = polyA_independent_alignments

        # Remove any alignments that are mostly to low complexity regions
        if soft_masked_regions:
            for a in alignments:
                tx_id = sam_input.getrname(a.reference_id)
                soft_masked_bases = soft_masked_regions[tx_id].intersection(
                    set(range(a.reference_start, a.reference_end)))
                soft_masked_fraction = float(len(soft_masked_bases)) / (a.reference_end - a.reference_start)
                a.setTag('XC', '%.2f' % soft_masked_fraction)
            # The filter deliberately reads the tag back, so it is applied to
            # the value rounded to two decimals.
            alignments = [a for a in alignments if float(a.opt('XC')) < soft_masked_fraction_threshold]

        # We need Transcript IDs in terms of reference names (Transcript_ID|Gene_ID)
        # as opposed to the arbitrary 'a.reference_id' number
        tx_ids = [sam_input.getrname(a.reference_id) for a in alignments]
        # Map to gene IDs and drop duplicates to get the set of genes hit by this read
        genes = set(tx_to_gid(tx_id) for tx_id in tx_ids)

        # Does the read map to multiple genes or just one?
        unique = True
        # Was the read non-unique, but then rescued to being unique?
        rescued_non_unique = False
        # Even after rescue, was the read rejected by the M threshold?
        failed_m_threshold = False

        # The same read could align to transcripts from different genes.
        if 1 < len(genes):
            unique = False
            close_alignments = [a for a in alignments
                                if (ref_lengths[a.reference_id] - a.reference_end) < distance_from_tx_end]
            close_genes = set(tx_to_gid(sam_input.getrname(a.reference_id)) for a in close_alignments)
            # Only restrict to the close alignments if that removes some, but not all, genes.
            if 0 < len(close_genes) < len(genes):
                alignments = close_alignments
                genes = close_genes
                if len(close_genes) == 1:
                    rescued_non_unique = True

        # Choose 1 alignment per gene, that we will write to the output BAM.
        chosen_alignments = {}
        keep_read = 0 < len(genes) <= multiple_alignment_threshold
        # With a mixed organism reference (genes named 'gene:ref'), additionally
        # require that all genes belong to a single reference.
        if using_mixed_ref:
            refs = set(g.split(':')[1] for g in genes)
            keep_read = (len(refs) == 1) and keep_read

        if keep_read:
            for gene in genes:
                gene_alignments = [a for a in alignments
                                   if tx_to_gid(sam_input.getrname(a.reference_id)) == gene]
                # Alignment to the longest transcript of the gene (first one on ties).
                chosen_alignments[gene] = max(gene_alignments, key=lambda a: ref_lengths[a.reference_id])
        else:
            failed_m_threshold = True

        read_filter_status = (unique, rescued_non_unique, failed_m_threshold, dependent_on_polyA_tail)
        return chosen_alignments, read_filter_status

    # --------------------------
    # Process SAM input
    # (we load everything into memory, so if a single barcode has truly very
    # deep sequencing, we could get into trouble)
    # --------------------------
    # Counters live in a dict so the nested helper can update them on Python 2
    # as well (no `nonlocal`).
    read_counts = {'unique': 0, 'rescued': 0, 'non_unique': 0, 'failed_m': 0, 'not_aligned': 0}
    reads_by_umi = defaultdict(dict)

    def record_read(read_name, read_alignments):
        # Process all alignments collected for one read and update the tallies.
        if not read_alignments:
            return
        chosen_alignments, status = process_read_alignments(read_alignments)
        unique, rescued, failed_m = status[0], status[1], status[2]
        if chosen_alignments:
            # Key by the read being flushed, NOT by the alignment that triggered
            # the flush: that one already belongs to the next read, and using
            # its name lets the final read overwrite the previous one.
            reads_by_umi[umi_from_read_name(read_name)][read_name] = chosen_alignments
        read_counts['unique'] += int(unique)
        read_counts['non_unique'] += int(not (unique or rescued or failed_m))
        read_counts['rescued'] += int(rescued)
        read_counts['failed_m'] += int(failed_m)

    current_read = None
    read_alignments = []
    for alignment in sam_input:
        # Skip alignments that failed to align...
        if alignment.reference_id == -1:
            read_counts['not_aligned'] += 1
            continue
        # A change of query name means the aligner moved on to a different
        # read, so process the previous one before proceeding.
        if current_read != alignment.query_name:
            record_read(current_read, read_alignments)
            current_read = alignment.query_name
            read_alignments = []
        read_alignments.append(alignment)
    # Flush the last read.
    record_read(current_read, read_alignments)

    # -----------------------------
    # Time to filter based on UMIs
    # (and output)
    # --------------------------
    umi_counts = defaultdict(float)
    ambig_umi_counts = defaultdict(float)
    ambig_gene_partners = defaultdict(set)
    ambig_clique_count = defaultdict(list)
    temp_sam_output = []

    for umi, umi_reads in reads_by_umi.items():
        # Invert the (read, gene) mapping:
        # gene -> {number of genes hit by the read -> set of alignments}
        aligns_by_gene = defaultdict(lambda: defaultdict(set))
        for read_genes in umi_reads.values():
            for gene, gene_alignment in read_genes.items():
                aligns_by_gene[gene][len(read_genes)].add(gene_alignment)

        # For each gene keep the alignments coming from the least ambiguous
        # reads (those that hit the fewest genes).
        best_alignment_for_gene = {}
        for gene, by_ambiguity in aligns_by_gene.items():
            best_alignment_for_gene[gene] = by_ambiguity[min(by_ambiguity.keys())]

        # Compute hitting set
        g0 = set.union(*(set(gs) for gs in umi_reads.values()))  # Union of the gene sets of all reads from that UMI
        r0 = set(umi_reads.keys())
        gene_read_mapping = dict()
        for g in g0:
            for r in r0:
                gene_read_mapping[(g, r)] = float(g in umi_reads[r]) / (len(umi_reads[r]) ** 2)

        # Keys are genes, values are the number of ambiguity partners of each gene
        target_genes = dict()
        while len(r0) > 0:
            # For each gene in g0, compute the weighted number of remaining reads pointing to it
            gene_contrib = dict((gi, sum(gene_read_mapping[(gi, r)] for r in r0)) for gi in g0)
            max_contrib = max(gene_contrib.values())
            # Genes with the maximal contribution. Built as a list (not
            # `filter(...)`) because it is indexed and measured below, which
            # fails on the iterator returned by Python 3's filter().
            max_contrib_genes = [gi for gi in gene_contrib.keys() if gene_contrib[gi] == max_contrib]
            # Pick a gene among those with the highest value. Which one doesn't matter until the last step
            g = max_contrib_genes[0]

            # Remove any reads from r0 that contributed to the picked gene.
            r0 = set(r for r in r0 if not gene_read_mapping[(g, r)])

            # If we had equivalent picks, and their gene contrib value is now 0,
            # they were ambiguity partners
            if len(max_contrib_genes) > 1:
                # Update the gene contribs based on the new r0, but on the 'old' g0.
                # That is why we remove g from g0 after this step only
                gene_contrib = dict((gi, sum(gene_read_mapping[(gi, r)] for r in r0)) for gi in g0)
                ambig_partners = [gi for gi in max_contrib_genes if gene_contrib[gi] == 0]

                # Ambig partners will often be a 1-element list. That's ok.
                # Then it will be equivalent to "target_genes[g] = 1."
                if len(ambig_partners) <= ambig_count_threshold:
                    if len(ambig_partners) == 1:
                        ambig_clique_count[0].append(umi)
                    for g_alt in ambig_partners:
                        ambig_gene_partners[g_alt].add(frozenset(ambig_partners))
                        target_genes[g_alt] = float(len(ambig_partners))
                        if len(ambig_partners) != 1:
                            ambig_clique_count[len(ambig_partners)].append(umi)
            else:
                target_genes[g] = 1.
                ambig_clique_count[1].append(umi)

            # Remove g here, so that g is part of the updated gene_contrib, when necessary
            g0.remove(g)

        # For each target gene, output the best alignment(s) and record the umi count
        for gene, ambigs in target_genes.items():
            supporting_alignments = best_alignment_for_gene[gene]
            if args.bam:
                for alignment_for_output in supporting_alignments:
                    # Add the following tags to aligned reads:
                    # XL - Library Name
                    # XB - Barcode Name
                    # XU - UMI sequence
                    # XO - Oversequencing number (how many reads with the same UMI are assigned to this gene)
                    # YG - Gene identity
                    # YK - Start of the alignment, relative to the transcriptome
                    # YL - End of the alignment, relative to the transcriptome
                    # YT - Length of alignment transcript
                    alignment_for_output.setTag('XL', args.library)
                    alignment_for_output.setTag('XB', args.barcode)
                    alignment_for_output.setTag('XU', umi)
                    alignment_for_output.setTag('XO', len(supporting_alignments))
                    alignment_for_output.setTag('YG', gene)
                    alignment_for_output.setTag('YK', int(alignment_for_output.pos))
                    alignment_for_output.setTag('YL', int(alignment_for_output.reference_end))
                    # Length of the transcript THIS alignment maps to (not that
                    # of whichever alignment an earlier loop happened to leave behind).
                    alignment_for_output.setTag('YT', int(ref_lengths[alignment_for_output.reference_id]))
                    temp_sam_output.append(alignment_for_output)

            split_between = ambigs if split_ambiguities else 1.
            umi_counts[gene] += 1. / split_between
            ambig_umi_counts[gene] += (1. / split_between if ambigs > 1 else 0)

    # Output the counts per gene
    all_genes = set(ref.split('|')[1] for ref in sam_input.references)
    sorted_all_genes = sorted(all_genes)
    sorted_metric_columns = ['total_input_reads','single_alignment','rescued_single_alignment','non_unique_less_than_m','non_unique_more_than_m','not_aligned','unambiguous_umifm','umifm_degrees_of_ambiguity_2','umifm_degrees_of_ambiguity_3','umifm_degrees_of_ambiguity_>3']
    output_umi_counts = [umi_counts[gene] for gene in sorted_all_genes]

    if args.write_header:
        args.counts.write('\t'.join(['barcode'] + sorted_all_genes) + '\n')
        args.ambigs.write('\t'.join(['barcode'] + sorted_all_genes) + '\n')
        args.metrics.write('\t'.join(["Barcode","Reads","Reads with unique alignment","Reads with unique alignment within shorter distance of 3'-end","Reads with less than `m` multiple alignments","Reads with more than than `m` multiple alignments","Reads with no alignments", "UMIFM","Ambig UMIFM (between 2 genes)","Ambig UMIFM (between 3 genes)","Ambig UMIFM (between more than 3 genes)",]) + '\n')

    if sum(output_umi_counts) >= args.min_counts:
        ignored = False
        args.counts.write('\t'.join([args.barcode] + [str(int(u)) for u in output_umi_counts]) + '\n')

        # Output sam data
        if sam_output is not None:
            for alignment_for_output in temp_sam_output:
                sam_output.write(alignment_for_output)

        # Output ambig data
        output_ambig_counts = [ambig_umi_counts[gene] for gene in sorted_all_genes]
        if sum(output_ambig_counts) > 0:
            args.ambigs.write('\t'.join([args.barcode] + [str(int(u)) for u in output_ambig_counts]) + '\n')
            output_ambig_partners = {}
            for gene in sorted_all_genes:
                if ambig_gene_partners[gene]:
                    gene_partners = frozenset.union(*ambig_gene_partners[gene]) - frozenset((gene,))
                    if gene_partners:
                        output_ambig_partners[gene] = gene_partners
            args.ambig_partners.write(args.barcode + '\t' + str(output_ambig_partners) + '\n')
    else:
        ignored = True
        with open(args.counts.name + '.ignored', 'a') as f:
            f.write(args.barcode + '\n')

    # Close the BAM on both paths so that an ignored barcode does not leave
    # the output handle open.
    if sam_output is not None:
        sam_output.close()

    args.counts.close()
    args.ambigs.close()
    args.ambig_partners.close()

    # Output the filtering metrics
    total_input_reads = (read_counts['unique'] + read_counts['rescued'] + read_counts['non_unique']
                         + read_counts['failed_m'] + read_counts['not_aligned'])
    metrics_data = {
        'total_input_reads': total_input_reads,
        'single_alignment': read_counts['unique'],
        'rescued_single_alignment': read_counts['rescued'],
        'non_unique_less_than_m': read_counts['non_unique'],
        'non_unique_more_than_m': read_counts['failed_m'],
        'not_aligned': read_counts['not_aligned'],
        'unambiguous_umifm': 0,
        'umifm_degrees_of_ambiguity_2': 0,
        'umifm_degrees_of_ambiguity_3': 0,
        'umifm_degrees_of_ambiguity_>3': 0,
    }
    for clique_size, clique_umis in ambig_clique_count.items():
        # Sizes 0 and 1 both denote an unambiguous assignment.
        if clique_size <= 1:
            metrics_data['unambiguous_umifm'] += len(clique_umis)
        elif clique_size == 2:
            metrics_data['umifm_degrees_of_ambiguity_2'] += len(clique_umis)
        elif clique_size == 3:
            metrics_data['umifm_degrees_of_ambiguity_3'] += len(clique_umis)
        else:
            metrics_data['umifm_degrees_of_ambiguity_>3'] += len(clique_umis)

    args.metrics.write('\t'.join([args.barcode] + [str(metrics_data[c]) for c in sorted_metric_columns]) + '\n')

    log_output_line = "{0:<8d}{1:<8d}{2:<10d}".format(
        total_input_reads,
        metrics_data['unambiguous_umifm'],
        metrics_data['umifm_degrees_of_ambiguity_2'] + metrics_data['umifm_degrees_of_ambiguity_3'] + metrics_data['umifm_degrees_of_ambiguity_>3'])
    if ignored:
        log_output_line += ' [Ignored from output]'
    print_to_log(log_output_line)
def build_arg_parser():
    """
    Build the command-line parser for this script.

    Returns an argparse.ArgumentParser whose parsed namespace is the `args`
    object expected by quant().
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', help='Ignore reads with more than M alignments, after filtering on distance from transcript end.', type=int, default=4)
    parser.add_argument('-u', help='Ignore counts from UMI that should be split among more than U genes.', type=int, default=4)
    parser.add_argument('-d', help='Maximal distance from transcript end.', type=int, default=525)
    parser.add_argument('--polyA', help='Length of polyA tail in reference transcriptome.', type=int, default=5)
    parser.add_argument('--split_ambi', help="If umi is assigned to m genes, add 1/m to each gene's count (instead of 1)", action='store_true', default=False)
    parser.add_argument('--mixed_ref', help="Reference is mixed, with records named 'gene:ref', should only keep reads that align to one ref.", action='store_true', default=False)
    parser.add_argument('--min_non_polyA', type=int, default=0)

    # Output tables are opened in append mode so that successive runs (one per
    # barcode) accumulate rows in the same files.
    parser.add_argument('--counts', type=argparse.FileType('a'))
    parser.add_argument('--metrics', type=argparse.FileType('a'))
    parser.add_argument('--ambigs', type=argparse.FileType('a'))
    parser.add_argument('--ambig-partners', type=argparse.FileType('a'))

    parser.add_argument('--barcode', type=str)
    parser.add_argument('--library', type=str, default='')
    parser.add_argument('--min-counts', type=int, default=0)
    parser.add_argument('--write-header', action='store_true')
    parser.add_argument('--bam', type=str, nargs='?', default='')
    # Opened in binary mode: the file is handed to pickle.load(), which
    # requires a bytes stream on Python 3.
    parser.add_argument('--soft-masked-regions', type=argparse.FileType('rb'), nargs='?')
    return parser


if __name__ == "__main__":
    # `sys` is imported here at module scope because code elsewhere in this
    # file may refer to the global name when run as a script.
    import sys

    args = build_arg_parser().parse_args()
    quant(args)