From 951ca562571cee887a9454b9b97f67fb2591df66 Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Wed, 1 Nov 2023 09:28:22 -0700 Subject: [PATCH] Add CIGAR tools module --- micall/core/contig_stitcher.py | 133 ++++++++ micall/tests/test_cigar_tools.py | 278 ++++++++++++++++ micall/tests/test_contig_stitcher.py | 61 ++++ micall/utils/cigar_tools.py | 475 +++++++++++++++++++++++++++ micall/utils/consensus_aligner.py | 3 +- 5 files changed, 949 insertions(+), 1 deletion(-) create mode 100644 micall/core/contig_stitcher.py create mode 100644 micall/tests/test_cigar_tools.py create mode 100644 micall/tests/test_contig_stitcher.py create mode 100644 micall/utils/cigar_tools.py diff --git a/micall/core/contig_stitcher.py b/micall/core/contig_stitcher.py new file mode 100644 index 000000000..9deffe179 --- /dev/null +++ b/micall/core/contig_stitcher.py @@ -0,0 +1,133 @@ +import argparse +import logging +import os +from typing import Iterable, Optional +from collections import Counter +from csv import DictWriter, DictReader +from dataclasses import dataclass +from datetime import datetime +from glob import glob +from io import StringIO +from itertools import chain +from operator import itemgetter +from shutil import rmtree +from subprocess import run, PIPE, CalledProcessError, STDOUT +from tempfile import mkdtemp +from mappy import Aligner + +from micall.utils.cigar_tools import connect_cigar_hits, CigarHit + + +@dataclass +class Contig: + name: str + seq: str + + +@dataclass +class GenotypedContig(Contig): + ref_name: str + ref_seq: str + matched_fraction: Optional[float] # Approximated overall concordance between `seq` and `ref_seq`. + + +@dataclass +class AlignedContig: + contig: GenotypedContig + alignment: CigarHit + + def reference_slice(self, r_st, r_en): + """ Narrows this alignment to a more specific region in reference. """ + + alignment = self.alignment.reference_slice(r_st, r_en) + contig = self.contig + return AlignedContig(contig, alignment) + + +def align_to_reference(contig: GenotypedContig): + aligner = Aligner(seq=contig.ref_seq, bw=500, bw_long=500, preset='map-ont') + alignments = list(aligner.map(contig.seq)) + if not alignments: + return AlignedContig(contig=contig, alignment=None) + + hits_array = [CigarHit(x.cigar, x.r_st, x.r_en - 1, x.q_st, x.q_en - 1) for x in alignments] + single_cigar_hit = connect_cigar_hits(hits_array) + return AlignedContig(contig=contig, alignment=single_cigar_hit) + + +def intervals_overlap(x, y): + """ Check if two intervals (x0, x1) and (y0, y1) overlap. """ + + if x[0] > y[1] or y[0] > x[1]: + return False + else: + return True + + +def interval_contains(x, y): + """ Check if interval (x0, x1) contains interval (y0, y1). """ + + return x[0] <= y[0] and x[1] >= y[1] + + +def find_all_overlapping_contigs(self, aligned_contigs): + for other in aligned_contigs: + if self.contig.ref_name != other.contig.ref_name: + continue + + if self.alignment.overlaps(other.alignment): + yield other + + +def find_overlapping_contig(self, aligned_contigs): + every = find_all_overlapping_contigs(self, aligned_contigs) + return max(chain(every, [None]), + key=lambda other: other.alignment.r_ei - other.alignment.r_st if other else 0) + + +def calculate_overlap(left, right): + left_seq_0 = left.contig.seq + right_seq_0 = right.contig.seq + + left_mapping = left.alignment.coordinate_mapping + right_mapping = right.alignment.coordinate_mapping + + overlap_interval = (right.alignment.r_st, left.alignment.r_ei) + # left_overlap_seq = [left.contig.ref_seq[i] for i in ] + + left_overlap = left.reference_slice(overlap_interval[0], overlap_interval[1]) + right_overlap = right.reference_slice(overlap_interval[0], overlap_interval[1]) + + + + +def stitch_contigs(contigs: Iterable[GenotypedContig]): + aligned = list(map(align_to_reference, contigs)) + + # Contigs that did not align do not need any more processing + stitched = [x.contig for x in aligned if x.alignment is None] + aligned = [x for x in aligned if x.alignment is not None] + + # Going left-to-right through aligned parts. + aligned = list(sorted(aligned, key=lambda x: x.alignment.r_st)) + while aligned: + current = aligned.pop(0) + + # Filter out all contigs that are contained within the current one. + # TODO: actually filter out if covered by multiple contigs + aligned = [x for x in aligned if not \ + interval_contains((current.alignment.r_st, current.alignment.r_ei), + (x.alignment.r_st, x.alignment.r_ei))] + + # Find overlap. If there isn't one - we are done with the current contig. + overlapping_contig = find_overlapping_contig(current, aligned) + if not overlapping_contig: + stitched.append(current.contig) + continue + + # Get overlaping regions + overlap = calculate_overlap(current, overlapping_contig) + + # aligned.append(combined) + + return stitched diff --git a/micall/tests/test_cigar_tools.py b/micall/tests/test_cigar_tools.py new file mode 100644 index 000000000..05eff8db8 --- /dev/null +++ b/micall/tests/test_cigar_tools.py @@ -0,0 +1,278 @@ +import pytest +from typing import List, Tuple + +from micall.utils.consensus_aligner import CigarActions +from micall.utils.cigar_tools import Cigar, CigarHit + + +cigar_mapping_cases: List[Tuple[Cigar, 'mapping', 'closest_mapping']] = [ + # Simple cases + ('3M', {0: 0, 1: 1, 2: 2}, + {0: 0, 1: 1, 2: 2}), + ('1M1D1M', {0: 0, 2: 1}, + {0: 0, 1: 0, 2: 1}), + ('1M1I1M', {0: 0, 1: 2}, + {0: 0, 1: 2}), + ('2M2D2M', {0: 0, 1: 1, 4: 2, 5: 3}, + {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3}), + ('2M2I2M', {0: 0, 1: 1, 2: 4, 3: 5}, + {0: 0, 1: 1, 2: 4, 3: 5}), + ('3M1D3M', {0: 0, 1: 1, 2: 2, 4: 3, 5: 4, 6: 5}, + {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 4, 6: 5}), + ('3M1I3M', {0: 0, 1: 1, 2: 2, 3: 4, 4: 5, 5: 6}, + {0: 0, 1: 1, 2: 2, 3: 4, 4: 5, 5: 6}), + ('7M1I3M', {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 8, 8: 9, 9: 10}, + {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 8, 8: 9, 9: 10}), + ('5M2D4M', {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 7: 5, 8: 6, 9: 7, 10: 8}, + {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 5, 7: 5, 8: 6, 9: 7, 10: 8}), + ('5M3I4M', {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 8, 6: 9, 7: 10, 8: 11}, + {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 8, 6: 9, 7: 10, 8: 11}), + ('1M1D', {0: 0}, + {0: 0}), + ('1M1I', {0: 0}, + {0: 0}), + ('1I1M', {0: 1}, + {0: 1}), + ('1D1M', {1: 0}, + {1: 0}), + + # Multiple deletions and insertions + ('2M2D2M2I2M', {0: 0, 1: 1, 4: 2, 5: 3, 6: 6, 7: 7}, + {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 6, 7: 7}), + ('2M2I2M2D2M', {0: 0, 1: 1, 2: 4, 3: 5, 6: 6, 7: 7}, + {0: 0, 1: 1, 2: 4, 3: 5, 4: 5, 5: 6, 6: 6, 7: 7}), + ('2=1X2N1N2=1H2S', {0: 0, 1: 1, 2: 2, 6: 3, 7: 4}, + {0: 0, 1: 1, 2: 2, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4}), + ('2M2D2M2I2M', {0: 0, 1: 1, 4: 2, 5: 3, 6: 6, 7: 7}, + {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 6, 7: 7}), + ('3=1X2N1N2=1H2S', {0: 0, 1: 1, 2: 2, 3: 3, 7: 4, 8: 5}, + {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 3, 6: 4, 7: 4, 8: 5}), + + # Edge cases + ('', {}, ValueError()), + ('12I', {}, ValueError()), + ('12D', {}, ValueError()), +] + + +@pytest.mark.parametrize("cigar_str, expected_mapping", [(x[0], x[1]) for x in cigar_mapping_cases]) +def test_cigar_to_coordinate_mapping(cigar_str, expected_mapping): + mapping = Cigar.coerce(cigar_str).coordinate_mapping + + assert expected_mapping == mapping.ref_to_query_d + assert expected_mapping == {i: mapping.ref_to_query(i) for i in mapping.ref_to_query_d} + + +@pytest.mark.parametrize("cigar_str", [x[0] for x in cigar_mapping_cases]) +def test_cigar_to_coordinate_bijection_property(cigar_str): + inverse = lambda d: {v: k for k, v in d.items()} + + mapping = Cigar.coerce(cigar_str).coordinate_mapping + + assert mapping.query_to_ref_d == inverse(mapping.ref_to_query_d) + assert mapping.ref_to_query_d == inverse(mapping.query_to_ref_d) + assert mapping.ref_to_query_d == inverse(inverse(mapping.ref_to_query_d)) + assert mapping.query_to_ref_d == inverse(inverse(mapping.query_to_ref_d)) + + +@pytest.mark.parametrize("cigar_str, expected_closest_mapping", [(x[0], x[2]) for x in cigar_mapping_cases]) +def test_cigar_to_closest_coordinate_mapping(cigar_str, expected_closest_mapping): + mapping = Cigar.coerce(cigar_str).coordinate_mapping + + if isinstance(expected_closest_mapping, Exception): + with pytest.raises(type(expected_closest_mapping)): + mapping.ref_to_closest_query(0) + + else: + fullrange = {i: mapping.ref_to_closest_query(i) \ + for i in range(min(mapping.ref_to_query_d), 1 + max(mapping.ref_to_query_d))} + assert expected_closest_mapping == fullrange + + +@pytest.mark.parametrize("cigar_str, expected_mapping", [(x[0], x[1]) for x in cigar_mapping_cases]) +def test_cigar_hit_to_coordinate_mapping(cigar_str, expected_mapping): + cigar = Cigar.coerce(cigar_str) + hit = CigarHit(cigar, r_st=5, r_ei=(5 + cigar.ref_length - 1), q_st=7, q_ei=(7 + cigar.query_length - 1)) + mapping = hit.coordinate_mapping + + # Coordinates are translated by q_st and r_st. + expected_mapping = {k + hit.r_st: v + hit.q_st for (k, v) in expected_mapping.items()} + assert mapping.ref_to_query(0) == None + assert mapping.query_to_ref(0) == None + assert expected_mapping \ + == {i: mapping.ref_to_query(i) for i in mapping.ref_to_query_d} + + +@pytest.mark.parametrize("cigar_str, expected_closest_mapping", [(x[0], x[2]) for x in cigar_mapping_cases]) +def test_cigar_hit_to_coordinate_closest_mapping(cigar_str, expected_closest_mapping): + cigar = Cigar.coerce(cigar_str) + hit = CigarHit(cigar, r_st=5, r_ei=(5 + cigar.ref_length - 1), q_st=7, q_ei=(7 + cigar.query_length - 1)) + mapping = hit.coordinate_mapping + + if isinstance(expected_closest_mapping, Exception): + with pytest.raises(type(expected_closest_mapping)): + mapping.ref_to_closest_query(0) + + else: + # Coordinates are translated by q_st and r_st. + expected_closest_mapping = {k + hit.r_st: v + hit.q_st for (k, v) in expected_closest_mapping.items()} + fullrange = {i: mapping.ref_to_closest_query(i) \ + for i in range(min(mapping.ref_to_query_d), 1 + max(mapping.ref_to_query_d))} + assert expected_closest_mapping == fullrange + + +def test_invalid_operation_in_cigar_string(): + with pytest.raises(ValueError): + Cigar.coerce('3M1Z3M') # Z operation is not implemented + + +def test_invalid_operation_in_cigar_list(): + with pytest.raises(ValueError): + Cigar.coerce([(3, 42)]) # Operation code "42" does not exist + + +def test_invalid_cigar_string(): + with pytest.raises(ValueError): + Cigar.coerce('3MMMMMM3M') # Too many Ms + with pytest.raises(ValueError): + Cigar.coerce('3') # Not enough Ms + + +cigar_hit_ref_cut_cases = [ + # Trivial cases + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 3.5, + [CigarHit('3M', r_st=1, r_ei=3, q_st=1, q_ei=3), + CigarHit('6M', r_st=4, r_ei=9, q_st=4, q_ei=9)]), + + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 4.5, + [CigarHit('4M', r_st=1, r_ei=4, q_st=1, q_ei=4), + CigarHit('5M', r_st=5, r_ei=9, q_st=5, q_ei=9)]), + + (CigarHit('9M', r_st=0, r_ei=8, q_st=0, q_ei=8), 3.5, + [CigarHit('4M', r_st=0, r_ei=3, q_st=0, q_ei=3), + CigarHit('5M', r_st=4, r_ei=8, q_st=4, q_ei=8)]), + + # Simple cases + (CigarHit('9M9D9M', r_st=1, r_ei=27, q_st=1, q_ei=18), 3.5, + [CigarHit('3M', r_st=1, r_ei=3, q_st=1, q_ei=3), + CigarHit('6M9D9M', r_st=4, r_ei=27, q_st=4, q_ei=18)]), + + (CigarHit('9M9D9M', r_st=1, r_ei=27, q_st=1, q_ei=18), 20.5, + [CigarHit('9M9D2M', r_st=1, r_ei=20, q_st=1, q_ei=11), + CigarHit('7M', r_st=21, r_ei=27, q_st=12, q_ei=18)]), + + (CigarHit('9M9I9M', r_st=1, r_ei=18, q_st=1, q_ei=27), 3.5, + [CigarHit('3M', r_st=1, r_ei=3, q_st=1, q_ei=3), + CigarHit('6M9I9M', r_st=4, r_ei=18, q_st=4, q_ei=27)]), + + (CigarHit('9M9I9M', r_st=1, r_ei=18, q_st=1, q_ei=27), 13.5 or 27/2, + [CigarHit('9M9I4M', r_st=1, r_ei=13, q_st=1, q_ei=22), + CigarHit('5M', r_st=14, r_ei=18, q_st=23, q_ei=27)]), + + # Ambigous cases + (CigarHit('9M9D9M', r_st=1, r_ei=27, q_st=1, q_ei=18), 13.5 or 27/2, + [CigarHit('9M4D', r_st=1, r_ei=13, q_st=1, q_ei=9), + CigarHit('5D9M', r_st=14, r_ei=27, q_st=10, q_ei=18)]), + + (CigarHit('9M9I9M', r_st=1, r_ei=18, q_st=1, q_ei=27), 9.2, + [CigarHit('9M1I', r_st=1, r_ei=9, q_st=1, q_ei=10), + CigarHit('8I9M', r_st=10, r_ei=18, q_st=11, q_ei=27)]), + + # Edge cases + (CigarHit('9M9I9M', r_st=1, r_ei=18, q_st=1, q_ei=27), 9.5, # no middlepoint + [CigarHit('9M5I', r_st=1, r_ei=9, q_st=1, q_ei=14), + CigarHit('4I9M', r_st=10, r_ei=18, q_st=15, q_ei=27)]), + + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 8.5, # one is singleton + [CigarHit('8M', r_st=1, r_ei=8, q_st=1, q_ei=8), + CigarHit('1M', r_st=9, r_ei=9, q_st=9, q_ei=9)]), + + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 9.5, # one is empty + [CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), + CigarHit('', r_st=10, r_ei=9, q_st=10, q_ei=9)]), + + (CigarHit('7M', r_st=3, r_ei=9, q_st=3, q_ei=9), 2.5, # one is empty + [CigarHit('', r_st=3, r_ei=2, q_st=3, q_ei=2), + CigarHit('7M', r_st=3, r_ei=9, q_st=3, q_ei=9)]), + + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 0.5, # one is empty around 0 + [CigarHit('', r_st=1, r_ei=0, q_st=1, q_ei=0), + CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9)]), + + (CigarHit('9M', r_st=0, r_ei=8, q_st=0, q_ei=8), -0.5, # another one is empty and negative + [CigarHit('', r_st=0, r_ei=-1, q_st=0, q_ei=-1), + CigarHit('9M', r_st=0, r_ei=8, q_st=0, q_ei=8)]), + + (CigarHit('9D', r_st=1, r_ei=9, q_st=1, q_ei=0), 3.5, + [CigarHit('3D', r_st=1, r_ei=3, q_st=1, q_ei=0), + CigarHit('6D', r_st=4, r_ei=9, q_st=1, q_ei=0)]), + + (CigarHit('9D', r_st=0, r_ei=8, q_st=0, q_ei=-1), -0.5, + [CigarHit('', r_st=0, r_ei=-1, q_st=0, q_ei=-1), + CigarHit('9D', r_st=0, r_ei=8, q_st=0, q_ei=-1)]), + + (CigarHit('2=1X2N1N2=1H2S', r_st=1, r_ei=8, q_st=1, q_ei=7), 3.5, + [CigarHit('2=1X', r_st=1, r_ei=3, q_st=1, q_ei=3), + CigarHit('3N2=1H2S', r_st=4, r_ei=8, q_st=4, q_ei=7)]), + + # Negative cases + (CigarHit('9M9I9M', r_st=1, r_ei=18, q_st=1, q_ei=27), 20.5, + IndexError("20.5 is bigger than reference (18)")), + + (CigarHit('', r_st=2, r_ei=1, q_st=2, q_ei=1), 2.5, + IndexError("Empty string cannot be cut")), + + (CigarHit('', r_st=2, r_ei=1, q_st=2, q_ei=1), 1.5, + IndexError("Empty string cannot be cut")), + + (CigarHit('9I', r_st=1, r_ei=0, q_st=1, q_ei=9), 3.5, + IndexError("Out of reference bounds")), + + (CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9), 4, + ValueError("Cut point must not be an integer")), + +] + +@pytest.mark.parametrize('hit, cut_point, expected_result', cigar_hit_ref_cut_cases) +def test_cigar_hit_ref_cut(hit, cut_point, expected_result): + if isinstance(expected_result, Exception): + with pytest.raises(type(expected_result)): + hit.cut_reference(cut_point) + + else: + expected_left, expected_right = expected_result + left, right = hit.cut_reference(cut_point) + assert expected_left == left + assert expected_right == right + + +@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases + if not isinstance(x[2], Exception)]) +def test_cigar_hit_ref_cut_add_prop(hit, cut_point): + left, right = hit.cut_reference(cut_point) + assert left + right == hit == right + left + + +@pytest.mark.parametrize("reference_seq, query_seq, cigar, expected_reference, expected_query", [ + ('ACTG', 'ACTG', '4M', 'ACTG', 'ACTG'), + ('ACTG', '', '4D', 'ACTG', '----'), + ('', 'ACTG', '4I', '----', 'ACTG'), + ('ACTGAC', 'ACAC', '2M2D2M', 'ACTGAC', 'AC--AC'), + ('ACAC', 'ACTGAC', '2M2I2M', 'AC--AC', 'ACTGAC'), + ('GCTATGGGAA', 'GCTATGGGAA', '5M3D2M', 'GCTATGGGAA', 'GCTAT---GG'), + ('ACTG', 'ACTG', '2M99H77P2M', 'ACTG', 'ACTG'), # Ignores non-consuming operations. +]) +def test_cigar_to_msa(reference_seq, query_seq, cigar, expected_reference, expected_query): + assert Cigar.coerce(cigar).to_msa(reference_seq, query_seq) \ + == (expected_reference, expected_query) + + +@pytest.mark.parametrize("cigar, reference_seq, query_seq", [ + ('10M', 'A' * 3, 'A' * 10), # reference is shorter than CIGAR + ('10M', 'A' * 10, 'A' * 3), # query is shorter than CIGAR + ('10D', 'A' * 3, 'A' * 3), + ('10I', 'A' * 3, 'A' * 3), +]) +def test_illigal_cigar_to_msa(cigar, reference_seq, query_seq): + with pytest.raises(ValueError): + Cigar.coerce(cigar).to_msa(reference_seq, query_seq) diff --git a/micall/tests/test_contig_stitcher.py b/micall/tests/test_contig_stitcher.py new file mode 100644 index 000000000..5001eeb2e --- /dev/null +++ b/micall/tests/test_contig_stitcher.py @@ -0,0 +1,61 @@ +import pytest +from micall.core.contig_stitcher import stitch_contigs, GenotypedContig + + +def test_1(): + contigs = [ + GenotypedContig(name='a', + seq='ACTGACTG' * 100, + ref_name='testref', + ref_seq='ACTGACTG' * 100, + matched_fraction=1.0, + ), + ] + + result = stitch_contigs(contigs) + assert result == contigs + + +def test_2(): + ref_seq = 'A' * 100 + + contigs = [ + GenotypedContig(name='a', + seq=ref_seq, + ref_name='testref', + ref_seq=ref_seq, + matched_fraction=0.5, + ), + GenotypedContig(name='b', + seq='C' * 100, + ref_name='testref', + ref_seq=ref_seq, + matched_fraction=0.5, + ), + ] + + result = stitch_contigs(contigs) + assert sorted(result, key=lambda x: x.name) \ + == sorted(contigs, key=lambda x: x.name) + + +def test_3(): + ref_seq = 'A' * 100 + 'C' * 100 + + contigs = [ + GenotypedContig(name='a', + seq='A' * 50 + 'C' * 20, + ref_name='testref', + ref_seq=ref_seq, + matched_fraction=0.5, + ), + GenotypedContig(name='b', + seq='A' * 20 + 'C' * 50, + ref_name='testref', + ref_seq=ref_seq, + matched_fraction=0.5, + ), + ] + + result = stitch_contigs(contigs) + assert False diff --git a/micall/utils/cigar_tools.py b/micall/utils/cigar_tools.py new file mode 100644 index 000000000..cffab829d --- /dev/null +++ b/micall/utils/cigar_tools.py @@ -0,0 +1,475 @@ +""" +Module for handling CIGAR strings and related alignment formats. +""" + +from math import ceil, floor +import re +from typing import List, Tuple, Iterable, Optional +from collections import OrderedDict +from dataclasses import dataclass +import itertools +import copy +from functools import cached_property + +from micall.utils.consensus_aligner import CigarActions + + +CIGAR_OP_MAPPING = { + 'M': CigarActions.MATCH, + 'I': CigarActions.INSERT, + 'D': CigarActions.DELETE, + 'N': CigarActions.SKIPPED, + 'S': CigarActions.SOFT_CLIPPED, + 'H': CigarActions.HARD_CLIPPED, + 'P': CigarActions.PADDING, + '=': CigarActions.SEQ_MATCH, + 'X': CigarActions.MISMATCH, +} + + +def parse_cigar_operation(operation: str) -> CigarActions: + if operation in CIGAR_OP_MAPPING: + return CIGAR_OP_MAPPING[operation] + else: + raise ValueError(f"Unexpected CIGAR action: {operation}.") + + +def cigar_operation_to_str(op: CigarActions) -> str: + return [k for (k, v) in CIGAR_OP_MAPPING.items() if v == op][0] + + +class CoordinateMapping: + def __init__(self): + self.query_to_ref_d = {} + self.ref_to_query_d = {} + self.ref_to_op_d = {} + self.query_to_op_d = {} + + + def extend(self, + ref_index: Optional[int], + query_index: Optional[int], + op_index: Optional[int]): + + if ref_index is not None and query_index is not None: + self.ref_to_query_d[ref_index] = query_index + self.query_to_ref_d[query_index] = ref_index + + if op_index is not None: + if ref_index is not None: + self.ref_to_op_d[ref_index] = op_index + if query_index is not None: + self.query_to_op_d[query_index] = op_index + + + def ref_to_query(self, index) -> Optional[int]: + return self.ref_to_query_d.get(index, None) + + + def query_to_ref(self, index) -> Optional[int]: + return self.query_to_ref_d.get(index, None) + + + @staticmethod + def _find_closest_key(mapping: dict, index: int) -> int: + return min(mapping, key=lambda k: abs(mapping[k] - index)) + + + def ref_to_closest_query(self, index) -> int: + return CoordinateMapping._find_closest_key(self.query_to_ref_d, index) + + + def query_to_closest_ref(self, index) -> int: + return CoordinateMapping._find_closest_key(self.ref_to_query_d, index) + + + def ref_to_leftsup_query(self, index) -> Optional[int]: + left_neihbourhood = (k for (k, v) in self.query_to_ref_d.items() if v <= index) + return max(left_neihbourhood, default=None) + + + def ref_to_rightinf_query(self, index) -> Optional[int]: + right_neihbourhood = (k for (k, v) in self.query_to_ref_d.items() if index <= v) + return min(right_neihbourhood, default=None) + + + def ref_or_query_to_op(self, ref_index: int, query_index: int, conflict): + r = self.ref_to_op_d.get(ref_index, None) + q = self.query_to_op_d.get(query_index, None) + if r is not None and q is not None: + return conflict(r, q) + + return r if q is None else q + + + def translate_coordinates(self, reference_offset: int, query_offset: int) -> 'CoordinateMapping': + ret = CoordinateMapping() + + ret.ref_to_query_d = {k + reference_offset: v + query_offset for (k, v) in self.ref_to_query_d.items()} + ret.query_to_ref_d = {k + query_offset: v + reference_offset for (k, v) in self.query_to_ref_d.items()} + ret.ref_to_op_d = {k + reference_offset: v for (k, v) in self.ref_to_op_d.items()} + ret.query_to_op_d = {k + query_offset: v for (k, v) in self.query_to_op_d.items()} + + return ret + + +class Cigar(list): + """ + A CIGAR string represents a read alignment against a reference sequence. + It is a run-length encoded sequence of alignment operations listed below: + + M: Alignment match (can be a sequence match or mismatch) + D: Deletion from the reference + I: Insertion to the reference + S: Soft clip on the read (ignored region, not aligned but present in the read) + H: Hard clip on the read (ignored region, not present in the read) + N: Skipped region from the reference + P: Padding (silent deletion from padded reference, not applicable for our case) + =: Sequence match + X: Sequence mismatch + + CIGAR strings are defined in the SAM specification + (https://samtools.github.io/hts-specs/SAMv1.pdf). + """ + + + def __init__(self, cigar_lst): + super().__init__([]) + for x in cigar_lst: self.append(x) + + + @staticmethod + def coerce(obj): + if isinstance(obj, Cigar): + return obj + + if isinstance(obj, str): + return Cigar.parse(obj) + + if isinstance(obj, list): + return Cigar(obj) + + raise TypeError(f"Cannot coerce {obj!r} to CIGAR string.") + + + @staticmethod + def parse(string): + data = [] + while string: + match = re.match(r'([0-9]+)([^0-9])', string) + if match: + num, operation = match.groups() + data.append([int(num), parse_cigar_operation(operation)]) + string = string[match.end():] + else: + raise ValueError(f"Invalid CIGAR string. Invalid part: {string[:20]}") + + return Cigar(data) + + + def append(self, item: Tuple[int, CigarActions]): + # Type checking + if not isinstance(item, list) and not isinstance(item, tuple): + raise ValueError(f"Invalid CIGAR list: {item!r} is not a tuple.") + if len(item) != 2: + raise ValueError(f"Invalid CIGAR list: {item!r} is has a bad length.") + + num, operation = item + if isinstance(operation, int): + operation = CigarActions(operation) + if not isinstance(num, int) or not isinstance(operation, CigarActions): + raise ValueError(f"Invalid CIGAR list: {item!r} is not a number/operation tuple.") + + # Normalization + if num == 0: + return + + if self: + last_num, last_operation = self[-1] + if operation == last_operation: + self[-1] = (last_num + num, operation) + return + + super().append((num, operation)) + + + def iterate_operations(self) -> Iterable[CigarActions]: + for num, operation in self: + for _ in range(num): + yield operation + + + def iterate_operations_with_pointers(self) -> Iterable[Tuple[CigarActions, Optional[int], Optional[int]]]: + ref_pointer = 0 + query_pointer = 0 + + for operation in self.iterate_operations(): + if operation in (CigarActions.MATCH, CigarActions.SEQ_MATCH, CigarActions.MISMATCH): + yield (operation, ref_pointer, query_pointer) + query_pointer += 1 + ref_pointer += 1 + + elif operation in (CigarActions.INSERT, CigarActions.SOFT_CLIPPED): + yield (operation, None, query_pointer) + query_pointer += 1 + + elif operation in (CigarActions.DELETE, CigarActions.SKIPPED): + yield (operation, ref_pointer, None) + ref_pointer += 1 + + else: + yield (operation, None, None) + + + @cached_property + def query_length(self): + return max((query_pointer + 1 if query_pointer is not None else 0 for (_, _, query_pointer) + in self.iterate_operations_with_pointers()), + default=0) + + + @cached_property + def ref_length(self): + return max((ref_pointer + 1 if ref_pointer is not None else 0 for (_, ref_pointer, _) + in self.iterate_operations_with_pointers()), + default=0) + + + def slice_operations(self, start_inclusive, end_noninclusive) -> 'Cigar': + return Cigar([(1, op) for op in self.iterate_operations()] + [start_inclusive:end_noninclusive]) + + + @cached_property + def coordinate_mapping(self) -> CoordinateMapping: + """ + Convert a CIGAR string to coordinate mapping representing a reference-to-query and query-to-reference coordinate mappings. + TODO: describe the domains and holes. + + :param cigar: a CIGAR string. + + :return: Lists of integers representing the mappings of coordinates from the reference + sequence to the query sequence, and back. + """ + + mapping = CoordinateMapping() + + for op_pointer, (operation, ref_pointer, query_pointer) in enumerate(self.iterate_operations_with_pointers()): + mapping.extend(ref_pointer, + query_pointer, + op_pointer) + + return mapping + + + def to_msa(self, reference_seq, query_seq) -> Tuple[str, str]: + reference_msa = '' + query_msa = '' + + for operation, ref_pointer, query_pointer in self.iterate_operations_with_pointers(): + if ref_pointer is None and query_pointer is None: + continue + + try: + if ref_pointer is not None: + reference_msa += reference_seq[ref_pointer] + else: + reference_msa += '-' + + if query_pointer is not None: + query_msa += query_seq[query_pointer] + else: + query_msa += '-' + + except IndexError: + raise ValueError("CIGAR string corresponds to a larger match than either reference or query.") + + return reference_msa, query_msa + + + def __repr__(self): + return f'Cigar({str(self)!r})' + + + def __str__(self): + """ Inverse of Cigar.parse """ + return ''.join('{}{}'.format(num, cigar_operation_to_str(op)) for num, op in self) + + +@dataclass +class CigarHit: + cigar: Cigar + r_st: int + r_ei: int # inclusive + q_st: int + q_ei: int # inclusive + + + def __post_init__(self): + self.cigar = Cigar.coerce(self.cigar) + + if self.r_len != self.cigar.ref_length: + raise ValueError(f"CIGAR string maps {self.cigar.ref_length}" + f" reference positions, but CIGAR hit range is {self.r_len}") + + if self.q_len != self.cigar.query_length: + raise ValueError(f"CIGAR string maps {self.cigar.query_length}" + f" query positions, but CIGAR hit range is {self.q_len}") + + + @property + def r_len(self): + return self.r_ei + 1 - self.r_st + + + @property + def q_len(self): + return self.q_ei + 1 - self.q_st + + + def overlaps(self, other) -> bool: + def intervals_overlap(x, y): + """ Check if two intervals [x0, x1] and [y0, y1] overlap. """ + return x[0] <= y[1] and x[1] >= y[0] + + return intervals_overlap((self.r_st, self.r_ei), (other.r_st, other.r_ei)) \ + or intervals_overlap((self.q_st, self.q_ei), (other.q_st, other.q_ei)) + + + def __add__(self, other): + """ + Inserts deletions/insertions between self and other, + then ajusts boundaries appropriately. + """ + + if self.overlaps(other): + raise ValueError("Cannot combine overlapping CIGAR hits") + + if (self.r_st, self.r_ei) < (other.r_st, other.r_ei): + # Note: in cases where one CigarHit is empty, comparing only by a single coordiate is not sufficient. + left = self + right = other + else: + left = other + right = self + + cigar = left.cigar \ + + Cigar.coerce([(right.r_st - left.r_ei - 1, CigarActions.DELETE)]) \ + + Cigar.coerce([(right.q_st - left.q_ei - 1, CigarActions.INSERT)]) \ + + right.cigar + + return CigarHit(cigar=cigar, + r_st=left.r_st, + r_ei=right.r_ei, + q_st=left.q_st, + q_ei=right.q_ei, + ) + + + def _slice(self, r_st, r_ei, q_st, q_ei) -> 'CigarHit': + mapping = self.coordinate_mapping + + o_st = mapping.ref_or_query_to_op(r_st, q_st, min) + o_ei = mapping.ref_or_query_to_op(r_ei, q_ei, max) + if o_st is None or o_ei is None: + cigar = Cigar([]) + else: + cigar = self.cigar.slice_operations(o_st, o_ei + 1) + + return CigarHit(cigar=cigar, + r_st = r_st, + r_ei = r_ei, + q_st = q_st, + q_ei = q_ei, + ) + + + def _ref_cut_to_query_cut(self, cut_point: float): + mapping = self.coordinate_mapping + + left_query_cut_point = mapping.ref_to_leftsup_query(floor(cut_point)) + right_query_cut_point = mapping.ref_to_rightinf_query(ceil(cut_point)) + + if left_query_cut_point is None: + return self.q_st - 0.1 + if right_query_cut_point is None: + return self.q_ei + 0.1 + + lerp = lambda start, end, t: (1 - t) * start + t * end + query_cut_point = lerp(left_query_cut_point, right_query_cut_point, + cut_point - floor(cut_point)) + + if float(query_cut_point).is_integer(): + # Disambiguate to the right. + query_cut_point += 0.1 / (self.r_st + self.r_ei + self.q_st + self.q_ei) + + return query_cut_point + + + def cut_reference(self, cut_point: float) -> 'CigarHit': + """ + Splits alignment in two parts such that cut_point is in between. + Guarantees that the two parts do not share any elements, + and that no element is lost. + """ + + if float(cut_point).is_integer(): + raise ValueError("Cut accepts fractions, not integers") + + if self.r_len == 0 or \ + not (self.r_st - 1 < cut_point < self.r_ei + 1): + raise IndexError("Cut point out of reference bounds") + + query_cut_point = self._ref_cut_to_query_cut(cut_point) + assert (self.q_st - 1 <= query_cut_point <= self.q_ei + 1) + + left = self._slice(self.r_st, floor(cut_point), + self.q_st, floor(query_cut_point)) + right = self._slice(ceil(cut_point), self.r_ei, + ceil(query_cut_point), self.q_ei) + + return left, right + + + @cached_property + def coordinate_mapping(self) -> CoordinateMapping: + return self.cigar.coordinate_mapping.translate_coordinates(self.r_st, self.q_st) + + + def to_msa(self, reference_seq: str, query_seq: str) -> Tuple[str, str]: + return self.cigar.to_msa(reference_seq[self.r_st:], query_seq[self.q_st:]) + + + def __repr__(self): + return f'CigarHit({str(self.cigar)!r}, r_st={self.r_st!r}, r_ei={self.r_ei!r}, q_st={self.q_st!r}, q_ei={self.q_ei!r})' + + +def connect_cigar_hits(cigar_hits: Iterable[CigarHit]) -> CigarHit: + """ + This function exists to deal with the fact that mappy does not always + connect big gaps, and returns surrounding parts as two separate alignment hits. + + For those cases we simply connect all the parts that do not overlap. + + Order of cigar_hits matters because we ignore alignments + that overlap with previously found alignments. + """ + + if not len(cigar_hits) > 0: + raise ValueError("Expected a non-empty list of cigar hits") + + accumulator = [] + + # Collect non-overlaping parts. + # Earlier matches have priority over ones that come after. + for hit in cigar_hits: + if any(earlier.overlaps(hit) for earlier in accumulator): + continue + + accumulator.append(hit) + + # Sort by interval start positions. + sorted_parts = sorted(accumulator, key=lambda p: p.r_st) + + # Collect all intervals back together, connecting them with CigarActions.DELETE. + return sum(sorted_parts[1:], start=sorted_parts[0]) diff --git a/micall/utils/consensus_aligner.py b/micall/utils/consensus_aligner.py index ab0bbdbc9..66e9ca6bf 100644 --- a/micall/utils/consensus_aligner.py +++ b/micall/utils/consensus_aligner.py @@ -23,9 +23,10 @@ # Most codons in an insertion or deletion that is still aligned in amino acids. MAXIMUM_AMINO_GAP = 10 +# Mapping as defined in https://samtools.github.io/hts-specs/SAMv1.pdf, page 8 CigarActions = IntEnum( 'CigarActions', - 'MATCH INSERT DELETE SKIPPED SOFT_CLIPPED HARD_CLIPPED', + 'MATCH INSERT DELETE SKIPPED SOFT_CLIPPED HARD_CLIPPED PADDING SEQ_MATCH MISMATCH', start=0)