diff --git a/micall/tests/test_cigar_tools.py b/micall/tests/test_cigar_tools.py index 6ab40b859..40a1fbc1a 100644 --- a/micall/tests/test_cigar_tools.py +++ b/micall/tests/test_cigar_tools.py @@ -4,7 +4,7 @@ import itertools from micall.utils.consensus_aligner import CigarActions -from micall.utils.cigar_tools import Cigar, CigarHit, parse_cigar_operation, CIGAR_OP_MAPPING +from micall.utils.cigar_tools import Cigar, CigarHit cigar_mapping_cases = [ @@ -330,6 +330,17 @@ def test_cigar_hit_ref_cut_add_prop_exhaustive(hit, cut_point): assert left + right == hit +@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases + if not isinstance(x[2], Exception)]) +def test_cigar_hit_strip_combines_with_connect(hit, cut_point): + left, right = hit.cut_reference(cut_point) + + left = left.rstrip_query() + right = right.lstrip_query() + + assert left.connect(right).coordinate_mapping == hit.coordinate_mapping + + @pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases if not isinstance(x[2], Exception) and not 'N' in str(x[0].cigar)]) @@ -339,7 +350,8 @@ def test_cigar_hit_strip_combines_with_add(hit, cut_point): left = left.rstrip_query() right = right.lstrip_query() - assert (left + right).coordinate_mapping == hit.coordinate_mapping + if left.touches(right): + assert left + right == hit @pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases diff --git a/micall/utils/cigar_tools.py b/micall/utils/cigar_tools.py index b0a082fe7..78a462595 100644 --- a/micall/utils/cigar_tools.py +++ b/micall/utils/cigar_tools.py @@ -6,38 +6,14 @@ import re from typing import Container, Tuple, Iterable, Optional, Set, Dict from dataclasses import dataclass -from functools import cached_property +from functools import cached_property, reduce from itertools import chain, dropwhile from fractions import Fraction from micall.utils.consensus_aligner import CigarActions -CIGAR_OP_MAPPING = { - 'M': 
CigarActions.MATCH, - 'I': CigarActions.INSERT, - 'D': CigarActions.DELETE, - 'N': CigarActions.SKIPPED, - 'S': CigarActions.SOFT_CLIPPED, - 'H': CigarActions.HARD_CLIPPED, - 'P': CigarActions.PADDING, - '=': CigarActions.SEQ_MATCH, - 'X': CigarActions.MISMATCH, -} - - -def parse_cigar_operation(operation: str) -> CigarActions: - if operation in CIGAR_OP_MAPPING: - return CIGAR_OP_MAPPING[operation] - else: - raise ValueError(f"Unexpected CIGAR action: {operation}.") - - -def cigar_operation_to_str(op: CigarActions) -> str: - return [k for (k, v) in CIGAR_OP_MAPPING.items() if v == op][0] - - -class PartialDict(dict): +class IntDict(dict): def __init__(self): super().__init__() self.domain = set() # superset of self.keys() @@ -67,8 +43,8 @@ def right_min(self, index) -> Optional[int]: return min((v for (k, v) in self.items() if k >= index), default=None) - def translate(self, domain_delta: int, codomain_delta: int) -> 'PartialDict': - ret = PartialDict() + def translate(self, domain_delta: int, codomain_delta: int) -> 'IntDict': + ret = IntDict() for k, v in self.items(): ret.extend(k + domain_delta, v + codomain_delta) @@ -85,10 +61,10 @@ def translate(self, domain_delta: int, codomain_delta: int) -> 'PartialDict': @dataclass class CoordinateMapping: def __init__(self): - self.query_to_ref = PartialDict() - self.ref_to_query = PartialDict() - self.ref_to_op = PartialDict() - self.query_to_op = PartialDict() + self.query_to_ref = IntDict() + self.ref_to_query = IntDict() + self.ref_to_op = IntDict() + self.query_to_op = IntDict() def extend(self, @@ -169,6 +145,32 @@ def coerce(obj): raise TypeError(f"Cannot coerce {obj!r} to CIGAR string.") + OP_MAPPING = { + 'M': CigarActions.MATCH, + 'I': CigarActions.INSERT, + 'D': CigarActions.DELETE, + 'N': CigarActions.SKIPPED, + 'S': CigarActions.SOFT_CLIPPED, + 'H': CigarActions.HARD_CLIPPED, + 'P': CigarActions.PADDING, + '=': CigarActions.SEQ_MATCH, + 'X': CigarActions.MISMATCH, + } + + + @staticmethod + def 
parse_operation(operation: str) -> CigarActions: + if operation in Cigar.OP_MAPPING: + return Cigar.OP_MAPPING[operation] + else: + raise ValueError(f"Unexpected CIGAR action: {operation}.") + + + @staticmethod + def operation_to_str(op: CigarActions) -> str: + return [k for (k, v) in Cigar.OP_MAPPING.items() if v == op][0] + + @staticmethod def parse(string): data = [] @@ -176,7 +178,7 @@ def parse(string): match = re.match(r'([0-9]+)([^0-9])', string) if match: num, operation = match.groups() - data.append([int(num), parse_cigar_operation(operation)]) + data.append([int(num), Cigar.parse_operation(operation)]) string = string[match.end():] else: raise ValueError(f"Invalid CIGAR string. Invalid part: {string[:20]}") @@ -317,7 +319,7 @@ def __repr__(self): def __str__(self): """ Inverse of Cigar.parse """ - return ''.join('{}{}'.format(num, cigar_operation_to_str(op)) for num, op in self) + return ''.join('{}{}'.format(num, Cigar.operation_to_str(op)) for num, op in self) @dataclass @@ -376,6 +378,17 @@ def intervals_overlap(x, y): or intervals_overlap((self.q_st, self.q_ei), (other.q_st, other.q_ei)) + def touches(self, other) -> bool: + """ + Checks if this CIGAR hit touches the other CIGAR hit, + in both reference and query space. + NOTE: only applicable if these hits come from the same reference and query. + """ + + return self.r_ei + 1 == other.r_st \ + and self.q_ei + 1 == other.q_st + + def gaps(self) -> Iterable['CigarHit']: # TODO(vitalik): memoize whatever possible. @@ -404,24 +417,32 @@ def make_gap(r_st, r_en): def __add__(self, other): """ - Inserts deletions/insertions between self and other, - then ajusts boundaries appropriately. + Only adds CigarHits that are touching. + The addition is simply a concatenation of two Cigar strings, and adjustment of hit coordinates. 
""" - if self.overlaps(other): - raise ValueError("Cannot combine overlapping CIGAR hits") - - cigar = self.cigar \ - + CigarHit.from_default_alignment(self.r_ei + 1, other.r_st - 1, self.q_ei + 1, other.q_st - 1).cigar \ - + other.cigar + if not self.touches(other): + raise ValueError("Cannot combine CIGAR hits that do not touch in both reference and query coordinates") - return CigarHit(cigar=cigar, + return CigarHit(cigar=self.cigar + other.cigar, r_st=self.r_st, r_ei=other.r_ei, q_st=self.q_st, q_ei=other.q_ei, ) + def connect(self, other): + """ + Inserts deletions/insertions between self and other, + then adjusts boundaries appropriately. + """ + + if self.overlaps(other): + raise ValueError("Cannot combine overlapping CIGAR hits") + + filler = CigarHit.from_default_alignment(self.r_ei + 1, other.r_st - 1, self.q_ei + 1, other.q_st - 1) + return self + filler + other + @property def epsilon(self): @@ -556,4 +577,4 @@ def connect_cigar_hits(cigar_hits: Iterable[CigarHit]) -> CigarHit: sorted_parts = sorted(accumulator, key=lambda p: p.r_st) # Collect all intervals back together, connecting them with CigarActions.DELETE. - return sum(sorted_parts[1:], start=sorted_parts[0]) + return reduce(CigarHit.connect, sorted_parts)