Skip to content

Commit

Permalink
Cigar tools: divide __add__ operation into connect and basic __add__
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Nov 13, 2023
1 parent 712ef3b commit 66fcafd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 46 deletions.
16 changes: 14 additions & 2 deletions micall/tests/test_cigar_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools

from micall.utils.consensus_aligner import CigarActions
from micall.utils.cigar_tools import Cigar, CigarHit, parse_cigar_operation, CIGAR_OP_MAPPING
from micall.utils.cigar_tools import Cigar, CigarHit


cigar_mapping_cases = [
Expand Down Expand Up @@ -330,6 +330,17 @@ def test_cigar_hit_ref_cut_add_prop_exhaustive(hit, cut_point):
assert left + right == hit


# Checks that `connect` restores what cutting and stripping removed:
# the hit is cut at a reference point, `rstrip_query` is applied to the left
# part and `lstrip_query` to the right part, and the connected result must
# have the same coordinate mapping as the original hit.
# Cases whose third element is an Exception instance are excluded.
@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
if not isinstance(x[2], Exception)])
def test_cigar_hit_strip_combines_with_connect(hit, cut_point):
left, right = hit.cut_reference(cut_point)

left = left.rstrip_query()
right = right.lstrip_query()

# Only the coordinate mappings are compared, not the hits themselves.
assert left.connect(right).coordinate_mapping == hit.coordinate_mapping


@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
if not isinstance(x[2], Exception)
and not 'N' in str(x[0].cigar)])
Expand All @@ -339,7 +350,8 @@ def test_cigar_hit_strip_combines_with_add(hit, cut_point):
left = left.rstrip_query()
right = right.lstrip_query()

assert (left + right).coordinate_mapping == hit.coordinate_mapping
if left.touches(right):
assert left + right == hit


@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
Expand Down
109 changes: 65 additions & 44 deletions micall/utils/cigar_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,14 @@
import re
from typing import Container, Tuple, Iterable, Optional, Set, Dict
from dataclasses import dataclass
from functools import cached_property
from functools import cached_property, reduce
from itertools import chain, dropwhile
from fractions import Fraction

from micall.utils.consensus_aligner import CigarActions


# Maps each single-character CIGAR operation code to its CigarActions member.
# Used for parsing (code -> action) and, by reverse lookup, for printing.
CIGAR_OP_MAPPING = {
'M': CigarActions.MATCH,
'I': CigarActions.INSERT,
'D': CigarActions.DELETE,
'N': CigarActions.SKIPPED,
'S': CigarActions.SOFT_CLIPPED,
'H': CigarActions.HARD_CLIPPED,
'P': CigarActions.PADDING,
'=': CigarActions.SEQ_MATCH,
'X': CigarActions.MISMATCH,
}


def parse_cigar_operation(operation: str) -> CigarActions:
"""
Look up the CigarActions member for a CIGAR operation code.

Raises ValueError if `operation` is not a key of CIGAR_OP_MAPPING.
"""
if operation in CIGAR_OP_MAPPING:
return CIGAR_OP_MAPPING[operation]
else:
raise ValueError(f"Unexpected CIGAR action: {operation}.")


def cigar_operation_to_str(op: CigarActions) -> str:
"""
Return the first key of CIGAR_OP_MAPPING whose value equals `op`.

Raises IndexError if no key maps to `op`, because the list of
matching keys is then empty.
"""
return [k for (k, v) in CIGAR_OP_MAPPING.items() if v == op][0]


class PartialDict(dict):
class IntDict(dict):
def __init__(self):
super().__init__()
self.domain = set() # superset of self.keys()
Expand Down Expand Up @@ -67,8 +43,8 @@ def right_min(self, index) -> Optional[int]:
return min((v for (k, v) in self.items() if k >= index), default=None)


def translate(self, domain_delta: int, codomain_delta: int) -> 'PartialDict':
ret = PartialDict()
def translate(self, domain_delta: int, codomain_delta: int) -> 'IntDict':
ret = IntDict()

for k, v in self.items():
ret.extend(k + domain_delta, v + codomain_delta)
Expand All @@ -85,10 +61,10 @@ def translate(self, domain_delta: int, codomain_delta: int) -> 'PartialDict':
@dataclass
class CoordinateMapping:
def __init__(self):
self.query_to_ref = PartialDict()
self.ref_to_query = PartialDict()
self.ref_to_op = PartialDict()
self.query_to_op = PartialDict()
self.query_to_ref = IntDict()
self.ref_to_query = IntDict()
self.ref_to_op = IntDict()
self.query_to_op = IntDict()


def extend(self,
Expand Down Expand Up @@ -169,14 +145,40 @@ def coerce(obj):
raise TypeError(f"Cannot coerce {obj!r} to CIGAR string.")


# Maps each single-character CIGAR operation code to its CigarActions member.
# Accessed as `Cigar.OP_MAPPING` by `parse_operation` and `operation_to_str`.
OP_MAPPING = {
'M': CigarActions.MATCH,
'I': CigarActions.INSERT,
'D': CigarActions.DELETE,
'N': CigarActions.SKIPPED,
'S': CigarActions.SOFT_CLIPPED,
'H': CigarActions.HARD_CLIPPED,
'P': CigarActions.PADDING,
'=': CigarActions.SEQ_MATCH,
'X': CigarActions.MISMATCH,
}


@staticmethod
def parse_operation(operation: str) -> CigarActions:
"""
Look up the CigarActions member for a CIGAR operation code.

Raises ValueError if `operation` is not a key of Cigar.OP_MAPPING.
"""
if operation in Cigar.OP_MAPPING:
return Cigar.OP_MAPPING[operation]
else:
raise ValueError(f"Unexpected CIGAR action: {operation}.")


@staticmethod
def operation_to_str(op: CigarActions) -> str:
"""
Return the first key of Cigar.OP_MAPPING whose value equals `op`.

Raises IndexError if no key maps to `op`, because the list of
matching keys is then empty.
"""
return [k for (k, v) in Cigar.OP_MAPPING.items() if v == op][0]


@staticmethod
def parse(string):
data = []
while string:
match = re.match(r'([0-9]+)([^0-9])', string)
if match:
num, operation = match.groups()
data.append([int(num), parse_cigar_operation(operation)])
data.append([int(num), Cigar.parse_operation(operation)])
string = string[match.end():]
else:
raise ValueError(f"Invalid CIGAR string. Invalid part: {string[:20]}")
Expand Down Expand Up @@ -317,7 +319,7 @@ def __repr__(self):

def __str__(self):
""" Inverse of Cigar.parse """
return ''.join('{}{}'.format(num, cigar_operation_to_str(op)) for num, op in self)
return ''.join('{}{}'.format(num, Cigar.operation_to_str(op)) for num, op in self)


@dataclass
Expand Down Expand Up @@ -376,6 +378,17 @@ def intervals_overlap(x, y):
or intervals_overlap((self.q_st, self.q_ei), (other.q_st, other.q_ei))


def touches(self, other) -> bool:
"""
Checks if this CIGAR hit touches the other CIGAR hit,
in both reference and query space.
The check is directional: `other` must start exactly one position
after `self` ends (`r_ei + 1 == r_st` and `q_ei + 1 == q_st`),
so `a.touches(b)` does not imply `b.touches(a)`.
NOTE: only applicable if these hits come from the same reference and query.
"""

return self.r_ei + 1 == other.r_st \
and self.q_ei + 1 == other.q_st


def gaps(self) -> Iterable['CigarHit']:
# TODO(vitalik): memoize whatever possible.

Expand Down Expand Up @@ -404,24 +417,32 @@ def make_gap(r_st, r_en):

def __add__(self, other):
"""
Inserts deletions/insertions between self and other,
then ajusts boundaries appropriately.
Only adds CigarHits that are touching.
The addition is simply a concatenation of two Cigar strings, and adjustment of hit coordinates.
"""

if self.overlaps(other):
raise ValueError("Cannot combine overlapping CIGAR hits")

cigar = self.cigar \
+ CigarHit.from_default_alignment(self.r_ei + 1, other.r_st - 1, self.q_ei + 1, other.q_st - 1).cigar \
+ other.cigar
if not self.touches(other):
raise ValueError("Cannot combine CIGAR hits that do not touch in both reference and query coordinates")

return CigarHit(cigar=cigar,
return CigarHit(cigar=self.cigar + other.cigar,
r_st=self.r_st,
r_ei=other.r_ei,
q_st=self.q_st,
q_ei=other.q_ei,
)

def connect(self, other):
"""
Inserts deletions/insertions between self and other,
then adjusts boundaries appropriately.

The gap between the end of `self` and the start of `other` is covered
by a filler hit built with `CigarHit.from_default_alignment`, and the
result is `self + filler + other`.

Raises ValueError if `self` and `other` overlap.
"""

if self.overlaps(other):
raise ValueError("Cannot combine overlapping CIGAR hits")

# The filler spans the positions strictly between the two hits, in both reference and query.
filler = CigarHit.from_default_alignment(self.r_ei + 1, other.r_st - 1, self.q_ei + 1, other.q_st - 1)
return self + filler + other


@property
def epsilon(self):
Expand Down Expand Up @@ -556,4 +577,4 @@ def connect_cigar_hits(cigar_hits: Iterable[CigarHit]) -> CigarHit:
sorted_parts = sorted(accumulator, key=lambda p: p.r_st)

# Collect all intervals back together, connecting them with CigarActions.DELETE.
return sum(sorted_parts[1:], start=sorted_parts[0])
return reduce(CigarHit.connect, sorted_parts)

0 comments on commit 66fcafd

Please sign in to comment.