forked from broadinstitute/viral-ngs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metagenomics.py
executable file
·1453 lines (1233 loc) · 58.7 KB
/
metagenomics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
''' This script contains a number of utilities for metagenomic analyses.
'''
from __future__ import print_function
from __future__ import division
__author__ = "[email protected]"
import argparse
import codecs
import collections
import csv
import gzip
import io
import itertools
import logging
import os.path
from os.path import join
import operator
import queue
import re
import shutil
import sys
import subprocess
import tempfile
import json
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pysam
import util.cmd
import util.file
import util.misc
import tools.bwa
import tools.diamond
import tools.kraken
import tools.krona
import tools.picard
from util.file import open_or_gzopen
__commands__ = []
log = logging.getLogger(__name__)
class TaxIdError(ValueError):
    '''Raised when a taxonomy ID cannot be determined for a record.'''
def maybe_compressed(fn):
    '''Return whichever of `fn` or `fn + '.gz'` exists, preferring `fn`.

    Raises:
        FileNotFoundError: if neither path exists (carries the plain `fn`).
    '''
    for candidate in (fn, fn + '.gz'):
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(fn)
class TaxonomyDb(object):
    '''Holder for tables parsed from NCBI taxonomy dump files.

    The path attributes (`tax_dir`, `gis_paths`, `nodes_path`, `names_path`)
    are always set. The lookup tables are only set when the corresponding
    `load_*` flag is true and a source (pre-built table or path) is available:

        gis:            dict gi -> taxid
        ranks, parents: dicts taxid -> rank string / parent taxid
        names:          dict taxid -> scientific name
    '''

    def __init__(
        self,
        tax_dir=None,
        gis=None,
        nodes=None,
        names=None,
        gis_paths=None,
        nodes_path=None,
        names_path=None,
        load_gis=False,
        load_nodes=False,
        load_names=False
    ):
        '''Record file locations and optionally load the lookup tables.

        Args:
            tax_dir: Directory containing gi_taxid_nucl.dmp, gi_taxid_prot.dmp,
                nodes.dmp and names.dmp (each may instead be present with a
                '.gz' suffix). When given, it overrides `gis_paths`,
                `nodes_path` and `names_path`.
            gis: Pre-built gi -> taxid table, used instead of reading files.
            nodes: Pre-built `(ranks, parents)` pair.
            names: Pre-built taxid -> name table.
            gis_paths: List of gi->taxid dmp files to read.
            nodes_path: Path of nodes.dmp.
            names_path: Path of names.dmp.
            load_gis, load_nodes, load_names: Which tables to populate.

        Raises:
            FileNotFoundError: if `tax_dir` is given and one of the four
                files is missing in both plain and '.gz' form.
        '''
        if tax_dir:
            gis_paths = [maybe_compressed(join(tax_dir, 'gi_taxid_nucl.dmp')),
                         maybe_compressed(join(tax_dir, 'gi_taxid_prot.dmp'))]
            nodes_path = maybe_compressed(join(tax_dir, 'nodes.dmp'))
            names_path = maybe_compressed(join(tax_dir, 'names.dmp'))
        self.tax_dir = tax_dir
        self.gis_paths = gis_paths
        self.nodes_path = nodes_path
        self.names_path = names_path

        # In each section a truthy pre-built table wins; an empty or missing
        # one falls back to reading from the path(s), if any.
        if load_gis:
            if gis:
                self.gis = gis
            elif gis_paths:
                self.gis = {}
                for gi_path in gis_paths:
                    log.info('Loading taxonomy gis: %s', gi_path)
                    self.gis.update(self.load_gi_single_dmp(gi_path))
        if load_nodes:
            if nodes:
                self.ranks, self.parents = nodes
            elif nodes_path:
                log.info('Loading taxonomy nodes: %s', nodes_path)
                self.ranks, self.parents = self.load_nodes(nodes_path)
        if load_names:
            if names:
                self.names = names
            elif names_path:
                log.info('Loading taxonomy names: %s', names_path)
                self.names = self.load_names(names_path)

    def load_gi_single_dmp(self, dmp_path):
        '''Load a gi->taxid dmp file from NCBI taxonomy.

        Each line must hold exactly two tab-separated integers: gi, taxid.

        Returns:
            dict gi -> taxid
        '''
        gi_array = {}
        with open_or_gzopen(dmp_path) as f:
            for i, line in enumerate(f):
                gi, taxid = line.strip().split('\t')
                gi_array[int(gi)] = int(taxid)
                loaded = i + 1
                if loaded % 1000000 == 0:
                    # Report the number of rows read so far (not the 0-based index).
                    log.info('Loaded %s gis', loaded)
        return gi_array

    def load_names(self, names_db, scientific_only=True):
        '''Load the names.dmp file from NCBI taxonomy.

        Args:
            names_db: Path of the '|'-delimited names file.
            scientific_only: If true, keep only rows whose name class is
                'scientific name' and return dict taxid -> name. Otherwise
                return a defaultdict taxid -> list of every name.
        '''
        if scientific_only:
            names = {}
        else:
            names = collections.defaultdict(list)
        # Context manager so the handle is closed deterministically.
        with open_or_gzopen(names_db) as f:
            for line in f:
                parts = line.strip().split('|')
                taxid = int(parts[0])
                name = parts[1].strip()
                # parts[2] is the unique name; unused.
                class_ = parts[3].strip()
                if scientific_only:
                    if class_ == 'scientific name':
                        names[taxid] = name
                else:
                    names[taxid].append(name)
        return names

    def load_nodes(self, nodes_db):
        '''Load ranks and parents from the NCBI taxonomy nodes.dmp file.

        Returns:
            (ranks, parents): dicts taxid -> rank string and taxid -> parent taxid.
        '''
        ranks = {}
        parents = {}
        with open_or_gzopen(nodes_db) as f:
            for line in f:
                parts = line.strip().split('|')
                taxid = int(parts[0])
                # Columns 3 and 4 (embl code, division id) are not used.
                parents[taxid] = int(parts[1])
                ranks[taxid] = parts[2].strip()
        return ranks, parents
# One parsed row of BLAST tabular (m8) output: the twelve standard columns in
# file order, followed by `extra`, which holds any remaining columns.
BlastRecord = collections.namedtuple(
    'BlastRecord',
    (
        'query_id',
        'subject_id',
        'percent_identity',
        'aln_length',
        'mismatch_count',
        'gap_open_count',
        'query_start',
        'query_end',
        'subject_start',
        'subject_end',
        'e_val',
        'bit_score',
        'extra',
    ),
)
def blast_records(f):
    '''Yield blast m8 records line by line.

    Lines starting with '#' and blank/whitespace-only lines are skipped.
    Columns are split on any whitespace. Columns 4-10 are converted to int,
    columns 3, 11 and 12 to float; anything after the twelfth column is
    collected (as strings) in the record's `extra` list.

    Args:
        f: Iterable of text lines.

    Yields:
        BlastRecord
    '''
    for line in f:
        if line.startswith('#'):
            continue
        parts = line.strip().split()
        if not parts:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing with IndexError below.
            continue
        for field in range(3, 10):
            parts[field] = int(parts[field])
        for field in (2, 10, 11):
            parts[field] = float(parts[field])
        args = parts[:12]
        args.append(parts[12:])
        yield BlastRecord(*args)
def paired_query_id(record):
    '''Strip a trailing '/1' or '/2' mate suffix from the record's query id.

    Args:
        record: A namedtuple with a `query_id` field (e.g. BlastRecord).

    Returns:
        A copy of `record` with the suffix removed from `query_id`, or
        `record` itself when there is no such suffix.
    '''
    for suffix in ('/1', '/2'):
        if record.query_id.endswith(suffix):
            # _replace copies the tuple with one field changed, avoiding the
            # list round-trip and positional index.
            return record._replace(query_id=record.query_id[:-len(suffix)])
    return record
def translate_gi_to_tax_id(db, record):
    '''Replace the gi header in the record's subject id with an int taxonomy id.

    The gi is taken from the second '|'-separated field of `subject_id`
    (e.g. 'gi|12345|...') and looked up in `db.gis`.

    Args:
        db: Object with a `gis` mapping of gi -> taxid.
        record: A namedtuple with a `subject_id` field (e.g. BlastRecord).

    Returns:
        A copy of `record` whose `subject_id` is the taxid.

    Raises:
        IndexError / ValueError: if `subject_id` has no integer second field.
        KeyError: if `db.gis` is a dict that lacks the gi.
    '''
    gi = int(record.subject_id.split('|')[1])
    return record._replace(subject_id=db.gis[gi])
def blast_m8_taxids(record):
    '''Return the record's subject id, converted to int, as a one-element list.'''
    tax_id = int(record.subject_id)
    return [tax_id]
def extract_tax_id(sam1):
    '''Parse the taxonomy id out of an alignment's reference name.

    The reference name is expected to look like 'taxid|<id>|...'.

    Args:
        sam1: Object with a `reference_name` string attribute.

    Returns:
        (int) The taxonomy id.

    Raises:
        TaxIdError: if the reference name does not start with 'taxid|'.
    '''
    fields = sam1.reference_name.split('|')
    if fields[0] != 'taxid':
        raise TaxIdError(fields)
    return int(fields[1])
def sam_lca(db, sam_file, output=None, top_percent=10, unique_only=True):
    ''' Calculate the LCA taxonomy id for multi-mapped reads in a samfile.

    Assumes the sam is sorted by query name. When `output` is given, writes
    one tsv line per reported query: C|U \t query_name \t tax_id, where the
    first column is 'C' for a non-zero tax id and 'U' otherwise.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        sam_file: (path) Sam file.
        output: (io) Output file.
        top_percent: (float) Only this percent within top hit are used.
        unique_only: (bool) If true, only output assignments for unique, mapped reads. If False, set unmapped or duplicate reads as unclassified.
    Return:
        (collections.Counter) Counter of taxid hits
    '''
    c = collections.Counter()
    with pysam.AlignmentFile(sam_file, 'rb') as sam:
        # groupby only merges adjacent records, hence the name-sorted requirement.
        for query_name, seg_group in itertools.groupby(sam, operator.attrgetter('query_name')):
            # 0x4 is unmapped, 0x400 is duplicate
            mapped_segs = [seg for seg in seg_group if seg.flag & 0x4 == 0 and seg.flag & 0x400 == 0]

            tax_id = 0  # 0 stands for "unclassified"
            if mapped_segs:
                tax_id = process_sam_hits(db, mapped_segs, top_percent)
                if tax_id is None:
                    # Logger.warn is a deprecated alias; use warning().
                    log.warning('Query: %s has no valid taxonomy paths.', query_name)
                    if unique_only:
                        continue
                    tax_id = 0
            elif unique_only:
                continue

            if output:
                classified = 'C' if tax_id else 'U'
                output.write('{}\t{}\t{}\n'.format(classified, query_name, tax_id))
            c[tax_id] += 1
    return c
def blast_lca(db,
              m8_file,
              output,
              paired=False,
              min_bit_score=50,
              max_expected_value=0.01,
              top_percent=10,):
    '''Calculate the LCA taxonomy id for groups of blast hits.

    Writes one tsv line per query: C|U \t query_id \t tax_id. Queries with
    no usable hits are written as 'U' with tax id 0 (the same convention
    sam_lca uses) rather than the literal string 'None'.

    Hits for one query must be adjacent in the input, because grouping only
    merges consecutive records.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        m8_file: (io) Blast m8 file to read.
        output: (io) Output file.
        paired: (bool) Whether to count paired suffixes /1,/2 as one group.
        min_bit_score: (float) Minimum bit score or discard.
        max_expected_value: (float) Maximum e-val or discard.
        top_percent: (float) Only this percent within top hit are used.
    '''
    records = blast_records(m8_file)
    records = (r for r in records if r.e_val <= max_expected_value)
    records = (r for r in records if r.bit_score >= min_bit_score)
    if paired:
        records = (paired_query_id(rec) for rec in records)
    for _, group in itertools.groupby(records, operator.attrgetter('query_id')):
        blast_group = list(group)
        tax_id = process_blast_hits(db, blast_group, top_percent)
        query_id = blast_group[0].query_id
        if not tax_id:
            log.debug('Query: %s has no valid taxonomy paths.', query_id)
            # Normalise None to 0 so the taxid column is always an integer.
            tax_id = 0
        classified = 'C' if tax_id else 'U'
        output.write('{}\t{}\t{}\n'.format(classified, query_id, tax_id))
def process_sam_hits(db, sam_hits, top_percent):
    '''Filter one query's sam alignments by score and compute their LCA.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        sam_hits: Non-empty list of alignments for one query; each must
            carry an 'AS' tag.
        top_percent: (float) Only consider hits within this percent of the
            top alignment score.
    Return:
        (int) Tax id of LCA, as returned by coverage_lca.
    '''
    def alignment_score(hit):
        return hit.get_tag('AS')

    top_score = max(alignment_score(hit) for hit in sam_hits)
    threshold = (100 - top_percent) / 100 * top_score
    kept = sorted(
        (hit for hit in sam_hits if alignment_score(hit) >= threshold),
        key=alignment_score,
        reverse=True,
    )
    return coverage_lca([extract_tax_id(hit) for hit in kept], db.parents)
def process_blast_hits(db, hits, top_percent):
    '''Filter one query's blast hits by bit score and compute their LCA.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        hits: []BlastRecord groups of hits.
        top_percent: (float) Only consider hits within this percent of top bit score.
    Return:
        (int) Tax id of LCA, or None when no hit has a non-zero tax id.
    '''
    translated = (translate_gi_to_tax_id(db, hit) for hit in hits)
    usable = [hit for hit in translated if hit.subject_id != 0]
    if not usable:
        return None

    top_score = max(hit.bit_score for hit in usable)
    threshold = (100 - top_percent) / 100 * top_score
    kept = sorted(
        (hit for hit in usable if hit.bit_score >= threshold),
        key=operator.attrgetter('bit_score'),
        reverse=True,
    )
    if not kept:
        return None
    tax_ids = tuple(tax_id for hit in kept for tax_id in blast_m8_taxids(hit))
    return coverage_lca(tax_ids, db.parents)
def coverage_lca(query_ids, parents, lca_percent=100):
    '''Calculate the lca that will cover at least this percent of queries.

    Each query id is walked up to the root (taxid 1) through `parents`.
    Ids whose walk hits a missing (or 0) parent are logged and left out of
    the vote, but they still count towards the required coverage because
    that is computed from len(query_ids).

    Args:
        query_ids: Sized sequence of int taxids.
        parents: dict of taxid -> parent taxid.
        lca_percent: (float) Cover at least this percent of queries.
    Return:
        (int) LCA, or None when no query id has a complete path to the root.
    '''
    lca_needed = lca_percent / 100 * len(query_ids)
    paths = []
    for query_id in query_ids:
        path = []
        while query_id != 1:
            path.append(query_id)
            if parents.get(query_id, 0) == 0:
                # Logger.warn is a deprecated alias; use warning() with lazy args.
                log.warning('Parent for query id: %s missing', query_id)
                break
            query_id = parents[query_id]
        if query_id == 1:
            # Only complete paths take part; store them root-first.
            path.append(1)
            path.reverse()
            paths.append(path)
    if not paths:
        return None

    last_common = 1
    max_path_length = max(len(path) for path in paths)
    for level in range(max_path_length):
        # Most frequent node at this depth among the paths that reach it.
        nodes_at_level = collections.Counter(path[level] for path in paths if len(path) > level)
        max_query_id, hits_covered = nodes_at_level.most_common(1)[0]
        if hits_covered < lca_needed:
            break
        last_common = max_query_id
    return last_common
def tree_level_lookup(parents, node, level_cache):
    '''Get the node level/depth, filling `level_cache` along the way.

    Walks from `node` towards the root until a node with a cached (truthy)
    level is found, then records the level of every node visited.

    Args:
        parents: Mapping of node -> parent node.
        node: Node to get level (root == 1).
        level_cache: Mutable mapping of previously found levels; must let
            the walk terminate (e.g. contain the root).
    Returns:
        (int) level of node
    '''
    pending = []
    current = node
    known_level = level_cache.get(current)
    while not known_level:
        pending.append(current)
        current = parents[current]
        known_level = level_cache.get(current)
    # `pending` runs from `node` upwards; reversed, the first entry is the
    # direct child of the cached ancestor.
    for offset, visited in enumerate(reversed(pending), start=1):
        level_cache[visited] = known_level + offset
    return known_level + len(pending)
def push_up_tree_hits(parents, hits, min_support_percent=None, min_support=None, update_assignments=False):
    '''Push up hits on nodes until min support is reached.

    Nodes with fewer than `min_support` hits hand their hits to their
    parent, deepest nodes first. If the root (1) itself ends up below the
    threshold its entry is deleted and processing stops.

    Args:
        parents: Mapping of node -> parent node.
        hits: Counter of hits on each node (mutated in place).
        min_support_percent: Push up hits until each node has
            this percent of the sum of all hits.
        min_support: Push up hits until each node has this number of hits.
        update_assignments: Accepted but has no effect.
    Returns:
        (counter) Hits mutated pushed up the tree.
    '''
    assert min_support_percent or min_support

    total_hits = sum(hits.values())
    if not min_support:
        min_support = round(min_support_percent * 0.01 * total_hits)

    depth_cache = {1: 1}

    def queue_entry(taxid):
        # Depth is negated so the deepest node is dequeued first.
        return (-tree_level_lookup(parents, taxid, depth_cache), taxid)

    pending = queue.PriorityQueue()
    for taxid, count in hits.items():
        if count < min_support:
            pending.put(queue_entry(taxid))

    while not pending.empty():
        _, taxid = pending.get()
        if hits[taxid] >= min_support:
            # Gained enough support from its children since being queued.
            continue
        if taxid == 1:
            del hits[1]
            break

        parent = parents[taxid]
        hits[parent] += hits[taxid]
        # Can't pop directly from hits because the id might not be stored in counter
        if taxid in hits:
            del hits[taxid]

        if hits[parent] < min_support:
            pending.put(queue_entry(parent))
    return hits
def parents_to_children(parents):
    '''Invert a node -> parent mapping into parent -> list of children.

    The root (node 1) is never listed as a child, and nodes whose parent is
    0 are skipped.

    Returns:
        (defaultdict[list]) Lists of children
    '''
    children = collections.defaultdict(list)
    for node, parent in parents.items():
        if node != 1 and parent != 0:
            children[parent].append(node)
    return children
def file_lines(filename):
    '''Yield the lines of `filename` (newlines included); nothing if it is None.'''
    if filename is None:
        return
    with open(filename) as handle:
        yield from handle
def collect_children(children, original_taxids):
    '''Yield the given taxids and, recursively, all of their descendants.

    Args:
        children: Mapping of taxid -> iterable of child taxids. Taxids
            without an entry are treated as leaves.
        original_taxids: Iterable of starting taxids. It is copied, so the
            caller's collection is left untouched.

    Yields:
        Taxids in no particular order.
    '''
    pending = set(original_taxids)
    while pending:
        taxid = pending.pop()
        yield taxid
        # .get() avoids inserting empty entries into a defaultdict and
        # avoids KeyError on a plain dict.
        pending.update(children.get(taxid, ()))
def collect_parents(parents, taxids):
    '''Yield the root, then each taxid and its not-yet-seen ancestors.

    Args:
        parents: Mapping of taxid -> parent taxid.
        taxids: Iterable of starting taxids.

    Yields:
        Each taxid at most once; the root taxid (1) always comes first.
    '''
    root = 1
    yield root
    seen = {root}
    for start in taxids:
        node = start
        # Climb until reaching a node whose lineage was already emitted.
        while node not in seen:
            yield node
            seen.add(node)
            node = parents[node]
def parser_subset_taxonomy(parser=None):
    '''Add the `subset_taxonomy` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument(
        "taxDb",
        help="Taxonomy database directory (containing nodes.dmp, names.dmp etc.)",
    )
    parser.add_argument(
        "outputDb",
        help="Output taxonomy database directory",
    )
    parser.add_argument(
        "--whitelistTaxids",
        help="List of taxids to add to taxonomy (with parents)",
        nargs='+', type=int
    )
    parser.add_argument(
        "--whitelistTaxidFile",
        help="File containing taxids - one per line - to add to taxonomy with parents.",
    )
    parser.add_argument(
        "--whitelistTreeTaxids",
        help="List of taxids to add to taxonomy (with parents and children)",
        nargs='+', type=int
    )
    parser.add_argument(
        "--whitelistTreeTaxidFile",
        help="File containing taxids - one per line - to add to taxonomy with parents and children.",
    )
    parser.add_argument(
        "--whitelistGiFile",
        help="File containing GIs - one per line - to add to taxonomy with nodes.",
    )
    parser.add_argument(
        "--whitelistAccessionFile",
        help="File containing accessions - one per line - to add to taxonomy with nodes.",
    )
    parser.add_argument(
        "--skipGi", action='store_true',
        help="Skip GI to taxid mapping files"
    )
    parser.add_argument(
        "--skipAccession", action='store_true',
        help="Skip accession to taxid mapping files"
    )
    parser.add_argument(
        "--skipDeadAccession", action='store_true',
        help="Skip dead accession to taxid mapping files"
    )
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, subset_taxonomy, split_args=True)
    return parser
def subset_taxonomy(taxDb, outputDb, whitelistTaxids=None, whitelistTaxidFile=None,
                    whitelistTreeTaxids=None, whitelistTreeTaxidFile=None,
                    whitelistGiFile=None, whitelistAccessionFile=None,
                    skipGi=None, skipAccession=None, skipDeadAccession=None,
                    stripVersion=True):
    '''
    Generate a subset of the taxonomy db files filtered by the whitelist. The
    whitelist taxids indicate specific taxids plus their parents to add to
    taxonomy while whitelistTreeTaxids indicate specific taxids plus both
    parents and all children taxa. Whitelist GI and accessions can only be
    provided in file form and the resulting gi/accession2taxid files will be
    filtered to only include those in the whitelist files. Finally, taxids +
    parents for the gis/accessions will also be included.
    '''
    # NOTE(review): the docstring above is kept verbatim because it may be
    # surfaced as CLI help by util.cmd.attach_main - confirm before editing.
    #
    # Output files are written under outputDb with the same relative names
    # as the inputs under taxDb.
    util.file.mkdir_p(os.path.join(outputDb, 'accession2taxid'))
    db = TaxonomyDb(tax_dir=taxDb, load_nodes=True)

    def read_ints(path):
        # One integer per line; blank lines are ignored.
        return set(int(x) for x in file_lines(path) if x.strip())

    taxids = set(whitelistTaxids or ())
    taxids.update(read_ints(whitelistTaxidFile))

    # Tree taxids from the file belong with the tree whitelist so that they
    # pull in their children too, not just their parents.
    tree_taxids = set(whitelistTreeTaxids or ())
    tree_taxids.update(read_ints(whitelistTreeTaxidFile))

    # Both whitelists contribute their parents, as the docstring promises.
    keep_taxids = set(collect_parents(db.parents, taxids | tree_taxids))

    if tree_taxids:
        db.children = parents_to_children(db.parents)
        # Pass a copy so the traversal cannot alter tree_taxids.
        keep_taxids.update(collect_children(db.children, set(tree_taxids)))

    # Whitelists consulted by filter_file; filled in below unless skipped.
    gis = set()
    accessions = set()
    accession_column_i = 0

    # Taxids kept based on GI or Accession. Get parents afterwards to not pull in all GIs/accessions.
    keep_seq_taxids = set()

    def filter_file(path, sep='\t', taxid_column=0, gi_column=None, a2t=False, header=False):
        # Copy the lines of taxDb/path to outputDb/path that match the GI or
        # accession whitelist or whose taxid is in keep_taxids.
        input_path = maybe_compressed(os.path.join(db.tax_dir, path))
        output_path = os.path.join(outputDb, path)

        with open_or_gzopen(input_path, 'rt') as f, \
             open_or_gzopen(output_path, 'wt') as out_f:
            if header:
                out_f.write(next(f))
            for line in f:
                parts = line.split(sep)
                taxid = int(parts[taxid_column])
                if gi_column is not None:
                    gi = int(parts[gi_column])
                    if gi in gis:
                        keep_seq_taxids.add(taxid)
                        out_f.write(line)
                        continue
                if a2t:
                    accession = parts[accession_column_i]
                    if stripVersion:
                        accession = accession.split('.', 1)[0]
                    if accession in accessions:
                        keep_seq_taxids.add(taxid)
                        out_f.write(line)
                        continue
                if taxid in keep_taxids:
                    out_f.write(line)

    if not skipGi:
        gis = read_ints(whitelistGiFile)

        filter_file('gi_taxid_nucl.dmp', taxid_column=1, gi_column=0)
        filter_file('gi_taxid_prot.dmp', taxid_column=1, gi_column=0)

    if not skipAccession:
        if stripVersion:
            accessions = set(x.strip().split('.', 1)[0] for x in file_lines(whitelistAccessionFile))
            accession_column_i = 0
        else:
            # Strip the trailing newline, otherwise nothing can ever match.
            accessions = set(x.strip() for x in file_lines(whitelistAccessionFile))
            accession_column_i = 1
        accessions.discard('')

        acc_dir = os.path.join(db.tax_dir, 'accession2taxid')
        acc_paths = []
        for fn in os.listdir(acc_dir):
            if fn.endswith(('.accession2taxid', '.accession2taxid.gz')):
                if skipDeadAccession and fn.startswith('dead_'):
                    continue
                acc_paths.append(os.path.join(acc_dir, fn))
        for acc_path in acc_paths:
            filter_file(os.path.relpath(acc_path, db.tax_dir), taxid_column=2, header=True, a2t=True)

    # Add in taxids found from processing GI/accession
    keep_taxids.update(collect_parents(db.parents, keep_seq_taxids))

    filter_file('nodes.dmp', sep='|')
    filter_file('names.dmp', sep='|')
    filter_file('merged.dmp')
    filter_file('delnodes.dmp')
__commands__.append(('subset_taxonomy', parser_subset_taxonomy))
def rank_code(rank):
    '''Get the short 1 letter rank code for named ranks ("-" if unrecognised).'''
    codes = {
        "species": "S",
        "genus": "G",
        "family": "F",
        "order": "O",
        "class": "C",
        "phylum": "P",
        "kingdom": "K",
        "superkingdom": "D",
        "unclassified": "U",
    }
    return codes.get(rank, "-")
def taxa_hits_from_tsv(f, taxid_column=2):
    '''Count how often each taxonomy id occurs in a tab-separated file.

    Args:
        f: Iterable of tsv text lines.
        taxid_column: 1-based index of the column holding the integer tax id.

    Returns:
        (collections.Counter) tax id -> number of rows.
    '''
    index = taxid_column - 1
    reader = csv.reader(f, delimiter='\t')
    return collections.Counter(int(row[index]) for row in reader)
def kraken_dfs_report(db, taxa_hits):
    '''Return a kraken compatible DFS report of taxa hits.

    Sets `db.children` from `db.parents` as a side effect. Hits on tax ids
    0 and -1 are reported together on a single 'unclassified' line, which
    comes first in the result.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        taxa_hits: (collections.Counter) # of hits per tax id.

    Return:
        []str lines of the report (always a list).
    '''
    db.children = parents_to_children(db.parents)
    total_hits = sum(taxa_hits.values())
    if total_hits == 0:
        return ['\t'.join(['100.00', '0', '0', 'U', '0', 'unclassified'])]

    # kraken_dfs appends children before their parent; the list is reversed
    # at the end so that parents precede their children.
    lines = []
    kraken_dfs(db, lines, taxa_hits, total_hits, 1, 0)
    unclassified_hits = taxa_hits.get(0, 0) + taxa_hits.get(-1, 0)
    if unclassified_hits > 0:
        percent_covered = '%.2f' % (unclassified_hits / total_hits * 100)
        lines.append(
            '\t'.join([
                str(percent_covered), str(unclassified_hits), str(unclassified_hits), 'U', '0', 'unclassified'
            ])
        )
    # Return a real list in both branches instead of a one-shot iterator.
    lines.reverse()
    return lines
def kraken_dfs(db, lines, taxa_hits, total_hits, taxid, level):
    '''Recursively do DFS for number of hits per taxa.

    Appends one tab-separated report line to `lines` for every taxon whose
    subtree has at least one hit. Lines are appended after the taxon's
    children have been processed.

    Args:
        db: Object with `children`, `ranks` and `names` mappings.
        lines: List that receives the report lines.
        taxa_hits: Mapping of tax id -> direct hits.
        total_hits: Denominator for the percentage column.
        taxid: Root of the subtree to report.
        level: Depth of `taxid`, used to indent its name.

    Returns:
        (int) Cumulative hits in the subtree rooted at `taxid`.
    '''
    num_hits = taxa_hits.get(taxid, 0)
    cum_hits = num_hits
    for child_taxid in db.children[taxid]:
        cum_hits += kraken_dfs(db, lines, taxa_hits, total_hits, child_taxid, level + 1)
    if cum_hits > 0:
        # Only resolve rank/name and format the percentage for taxa that are
        # actually reported; zero-hit taxa need neither, and may be absent
        # from db.ranks / db.names.
        percent_covered = '%.2f' % (cum_hits / total_hits * 100)
        rank = rank_code(db.ranks[taxid])
        name = db.names[taxid]
        lines.append('\t'.join([percent_covered, str(cum_hits), str(num_hits), rank, str(taxid), ' ' * level + name]))
    return cum_hits
def parser_kraken(parser=None):
    '''Add the `kraken` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('db', help='Kraken database directory.')
    parser.add_argument('inBams', nargs='+', help='Input unaligned reads, BAM format.')
    parser.add_argument('--outReports', nargs='+', help='Kraken summary report output file. Multiple filenames space separated.')
    parser.add_argument('--outReads', nargs='+', help='Kraken per read classification output file. Multiple filenames space separated.')
    parser.add_argument('--lockMemory', action='store_true', default=False, help='Lock kraken database in RAM. Requires high ulimit -l.')
    parser.add_argument(
        '--filterThreshold', default=0.05, type=float, help='Kraken filter threshold (default %(default)s)'
    )
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, kraken, split_args=True)
    return parser
def kraken(db, inBams, outReports=None, outReads=None, lockMemory=False, filterThreshold=None, threads=None):
    '''
        Classify reads by taxon using Kraken
    '''
    # NOTE(review): docstring kept as-is because it may be surfaced as CLI
    # help by util.cmd.attach_main - confirm before editing.
    # At least one output must be requested; the message names the real
    # option spellings (--outReports, not --outReport).
    assert outReads or outReports, ('Either --outReads or --outReports must be specified.')
    kraken_tool = tools.kraken.Kraken()
    kraken_tool.pipeline(db, inBams, outReports=outReports, outReads=outReads, lockMemory=lockMemory,
                         filterThreshold=filterThreshold, numThreads=threads)
__commands__.append(('kraken', parser_kraken))
def parser_krona(parser=None):
    '''Add the `krona` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inTsv', help='Input tab delimited file.')
    parser.add_argument('db', help='Krona taxonomy database directory.')
    parser.add_argument('outHtml', help='Output html report.')
    parser.add_argument('--queryColumn', help='Column of query id. (default %(default)s)', type=int, default=2)
    parser.add_argument('--taxidColumn', help='Column of taxonomy id. (default %(default)s)', type=int, default=3)
    parser.add_argument('--scoreColumn', help='Column of score. (default %(default)s)', type=int, default=None)
    parser.add_argument('--magnitudeColumn', help='Column of magnitude. (default %(default)s)', type=int, default=None)
    parser.add_argument('--noHits', help='Include wedge for no hits.', action='store_true')
    parser.add_argument('--noRank', help='Include no rank assignments.', action='store_true')
    util.cmd.common_args(parser, (('loglevel', None), ('version', None)))
    util.cmd.attach_main(parser, krona, split_args=True)
    return parser
def krona(inTsv, db, outHtml, queryColumn=None, taxidColumn=None, scoreColumn=None, magnitudeColumn=None, noHits=None, noRank=None):
    '''
        Create an interactive HTML report from a tabular metagenomic report
    '''
    # NOTE(review): docstring kept as-is because it may be surfaced as CLI
    # help by util.cmd.attach_main - confirm before editing.
    krona_tool = tools.krona.Krona()
    root_name = os.path.basename(inTsv)

    tmp_tsv = None
    try:
        if inTsv.endswith('.gz'):
            # A gzipped input is decompressed to a temp file before import.
            tmp_tsv = util.file.mkstempfname('.tsv')
            with gzip.open(inTsv, 'rb') as f_in, open(tmp_tsv, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            to_import = [tmp_tsv]
        else:
            to_import = [inTsv]

        krona_tool.import_taxonomy(
            db,
            to_import,
            outHtml,
            query_column=queryColumn,
            taxid_column=taxidColumn,
            score_column=scoreColumn,
            magnitude_column=magnitudeColumn,
            root_name=root_name,
            no_hits=noHits,
            no_rank=noRank
        )
    finally:
        # Cleanup tmp .tsv file even when decompression or the import fails.
        if tmp_tsv is not None and os.path.exists(tmp_tsv):
            os.unlink(tmp_tsv)
__commands__.append(('krona', parser_krona))
def parser_diamond(parser=None):
    '''Add the `diamond` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input unaligned reads, BAM format.')
    parser.add_argument('db', help='Diamond database directory.')
    parser.add_argument('taxDb', help='Taxonomy database directory.')
    parser.add_argument('outReport', help='Output taxonomy report.')
    parser.add_argument('--outReads', help='Output LCA assignments for each read.')
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, diamond, split_args=True)
    return parser
def diamond(inBam, db, taxDb, outReport, outReads=None, threads=None):
    '''
        Classify reads by the taxon of the Lowest Common Ancestor (LCA)
    '''
    # NOTE(review): docstring kept as-is because it may be surfaced as CLI
    # help by util.cmd.attach_main - confirm before editing.
    import shlex  # local import: only needed to quote the shell pipeline below
    quote = shlex.quote

    # do not convert this to samtools bam2fq unless we can figure out how to replicate
    # the clipping functionality of Picard SamToFastq
    picard = tools.picard.SamToFastqTool()
    with util.file.fifo(2) as (fastq_pipe, diamond_pipe):
        # Picard writes interleaved fastq to the first fifo in the background.
        s2fq = picard.execute(
            inBam,
            fastq_pipe,
            interleave=True,
            illuminaClipping=True,
            JVMmemory=picard.jvmMemDefault,
            background=True,
        )

        diamond_tool = tools.diamond.Diamond()
        taxonmap = join(taxDb, 'accession2taxid', 'prot.accession2taxid.gz')
        taxonnodes = join(taxDb, 'nodes.dmp')
        rutils = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'read_utils.py')

        # The pipeline runs through bash, so every interpolated path is
        # shell-quoted; names with spaces or metacharacters would otherwise
        # break (or alter) the command.
        cmd = '{read_utils} join_paired_fastq --outFormat fasta /dev/stdout {fastq}'.format(
            read_utils=quote(rutils), fastq=quote(fastq_pipe))

        cmd += ' | {} blastx --out {output} --outfmt 102 --sallseqid'.format(
            quote(diamond_tool.install_and_get_path()), output=quote(diamond_pipe))
        cmd += ' --threads {threads}'.format(threads=util.misc.sanitize_thread_count(threads))
        cmd += ' --db {db} --taxonmap {taxonmap} --taxonnodes {taxonnodes}'.format(
            db=quote(db),
            taxonmap=quote(taxonmap),
            taxonnodes=quote(taxonnodes))

        if outReads is not None:
            # Interstitial save of stdout to output file
            cmd += ' | tee >(pigz --best > {out})'.format(out=quote(outReads))

        # bash is required for the >( ) process substitution above.
        diamond_ps = subprocess.Popen(cmd, shell=True, executable='/bin/bash')

        tax_db = TaxonomyDb(tax_dir=taxDb, load_names=True, load_nodes=True)

        # Diamond's results arrive on the second fifo.
        with open(diamond_pipe) as lca_p:
            hits = taxa_hits_from_tsv(lca_p)
            with open(outReport, 'w') as f:
                for line in kraken_dfs_report(tax_db, hits):
                    print(line, file=f)

        s2fq.wait()
        assert s2fq.returncode == 0
        diamond_ps.wait()
        assert diamond_ps.returncode == 0
__commands__.append(('diamond', parser_diamond))
def parser_diamond_fasta(parser=None):
    '''Add the `diamond_fasta` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inFasta', help='Input sequences, FASTA format, optionally gzip compressed.')
    parser.add_argument('db', help='Diamond database file.')
    parser.add_argument('taxDb', help='Taxonomy database directory.')
    parser.add_argument('outFasta', help='Output sequences, same as inFasta, with taxid|###| prepended to each sequence identifier.')
    parser.add_argument('--memLimitGb', type=float, default=None, help='approximate memory usage in GB')
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, diamond_fasta, split_args=True)
    return parser
def diamond_fasta(inFasta, db, taxDb, outFasta, threads=None, memLimitGb=None):
    '''
        Classify fasta sequences by the taxon of the Lowest Common Ancestor (LCA)
    '''
    # NOTE(review): docstring kept as-is because it may be surfaced as CLI
    # help by util.cmd.attach_main - confirm before editing.
    unclassified_taxid = '32644'  # diamond returns 0 if unclassified, but the proper taxID for that is 32644

    with util.file.tmp_dir() as tmp_dir:
        # run diamond blastx on fasta sequences
        cmd = [tools.diamond.Diamond().install_and_get_path(), 'blastx',
               '-q', inFasta,
               '--db', db,
               '--outfmt', '102',  # tsv: query name, taxid of LCA, e-value
               '--salltitles',  # to recover the entire fasta sequence name
               '--sensitive',  # this is necessary for longer reads or contigs
               '--algo', '1',  # for small query files
               '--threads', str(util.misc.sanitize_thread_count(threads)),
               '--taxonmap', os.path.join(taxDb, 'accession2taxid', 'prot.accession2taxid.gz'),
               '--taxonnodes', os.path.join(taxDb, 'nodes.dmp'),
               '--tmpdir', tmp_dir,
               ]
        if memLimitGb:
            cmd.extend(['--block-size', str(round(memLimitGb / 5.0, 1))])
        log.debug(' '.join(cmd))
        diamond_p = subprocess.Popen(cmd, stdout=subprocess.PIPE)

        # read the output report and load into an in-memory map
        # of sequence ID -> tax ID (there shouldn't be that many sequences)
        seq_to_tax = {}
        for line in diamond_p.stdout:
            row = line.decode('UTF-8').rstrip('\n\r').split('\t')
            seq_to_tax[row[0]] = row[1] if row[1] != '0' else unclassified_taxid
        # wait() rather than poll(): reaching EOF on stdout does not mean the
        # process has exited yet, and poll() would then report None and let a
        # failure go unnoticed.
        if diamond_p.wait():
            raise subprocess.CalledProcessError(diamond_p.returncode, cmd)

    # copy inFasta to outFasta while prepending taxid|###| to all sequence names
    log.debug("transforming {} to {}".format(inFasta, outFasta))
    with util.file.open_or_gzopen(inFasta, 'rt') as inf:
        with util.file.open_or_gzopen(outFasta, 'wt') as outf:
            # The module imports `SeqIO` directly; the name `Bio` is not bound.
            for seq in SeqIO.parse(inf, 'fasta'):
                taxid = seq_to_tax.get(seq.id, unclassified_taxid)  # default to "unclassified"
                # str.join takes a single iterable, not separate arguments.
                new_id = '|'.join(('taxid', taxid, seq.id))
                for text_line in util.file.fastaMaker([(new_id, str(seq.seq))]):
                    outf.write(text_line)
__commands__.append(('diamond_fasta', parser_diamond_fasta))
def parser_build_diamond_db(parser=None):
    '''Add the `build_diamond_db` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('protein_fastas', nargs='+', help='Input protein fasta files')
    parser.add_argument('db', help='Output Diamond database file')
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, build_diamond_db, split_args=True)
    return parser
def build_diamond_db(protein_fastas, db, threads=None):
    # Build a Diamond database `db` from the given protein fasta files.
    # (Plain comment rather than a docstring: the function had none, and a
    # docstring may be surfaced as CLI help by util.cmd.attach_main.)
    # The module is imported as `tools.diamond`; `tool` is not a bound name.
    tools.diamond.Diamond().build(db, protein_fastas, options={'threads':str(util.misc.sanitize_thread_count(threads))})
__commands__.append(('build_diamond_db', parser_build_diamond_db))
def parser_align_rna_metagenomics(parser=None):
    '''Add the `align_rna_metagenomics` command-line arguments to `parser`.

    Args:
        parser: argparse parser to populate; a new one is created when
            omitted, so that separate calls never share one instance.

    Returns:
        The populated parser.
    '''
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input unaligned reads, BAM format.')
    parser.add_argument('db', help='Bwa index prefix.')
    parser.add_argument('taxDb', help='Taxonomy database directory.')
    parser.add_argument('outReport', help='Output taxonomy report.')
    parser.add_argument('--dupeReport', help='Generate report including duplicates.')
    parser.add_argument(
        '--sensitive',
        dest='sensitive',
        action="store_true",
        help='Use sensitive instead of default BWA mem options.'
    )
    parser.add_argument('--outBam', help='Output aligned, indexed BAM file. Default is to write to temp.')
    parser.add_argument('--outReads', help='Output LCA assignments for each read.')
    parser.add_argument('--dupeReads', help='Output LCA assignments for each read including duplicates.')
    parser.add_argument(
        '--JVMmemory',
        default=tools.picard.PicardTools.jvmMemDefault,
        help='JVM virtual memory size (default: %(default)s)'
    )
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, align_rna_metagenomics, split_args=True)
    return parser
def align_rna_metagenomics(
inBam,
db,
taxDb,
outReport,
dupeReport=None,
outBam=None,
dupeReads=None,
outReads=None,
sensitive=None,
JVMmemory=None,
threads=None,
picardOptions=None,
):
'''
Align to metagenomics bwa index, mark duplicates, and generate LCA report
'''
picardOptions = picardOptions if picardOptions else []
bwa = tools.bwa.Bwa()
samtools = tools.samtools.SamtoolsTool()
bwa_opts = ['-a']
if sensitive:
bwa_opts += '-k 12 -A 1 -B 1 -O 1 -E 1'.split()
# TODO: Use bwa.mem's min_score_to_filter argument to decrease false
# positives in the output. Currently, it works by summing the alignment
# score across all alignments output by bwa for each query (reads in a
# pair, supplementary, and secondary alignments). This is not reasonable
# for reads with secondary alignments because it will be easier for those
# reads/queries to exceed the threshold given by the value of the argument.
# In this context, bwa is called using '-a' as an option and its output
# will likely include many secondary alignments. One option is to add
# another argument to bwa.mem, similar to min_score_to_filter, that sets a
# threshold on the alignment score of output alignments but only filters on
# a per-alignment level (i.e., not by summing alignment scores across all
# alignments for each query).
aln_bam = util.file.mkstempfname('.bam')
bwa.mem(inBam, db, aln_bam, options=bwa_opts)
tax_db = TaxonomyDb(tax_dir=taxDb, load_names=True, load_nodes=True)
if dupeReport:
aln_bam_sorted = util.file.mkstempfname('.align_namesorted.bam')
samtools.sort(aln_bam, aln_bam_sorted, args=['-n'], threads=threads)