Skip to content

Commit

Permalink
Add contig stitcher module
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Nov 6, 2023
1 parent 951ca56 commit 46dbd45
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 45 deletions.
145 changes: 107 additions & 38 deletions micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import logging
import os
from typing import Iterable, Optional
from collections import Counter
from typing import Iterable, Optional, Tuple, List
from collections import Counter, deque
from csv import DictWriter, DictReader
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -14,6 +14,8 @@
from subprocess import run, PIPE, CalledProcessError, STDOUT
from tempfile import mkdtemp
from mappy import Aligner
from functools import cached_property
from gotoh import align_it

from micall.utils.cigar_tools import connect_cigar_hits, CigarHit

Expand All @@ -30,43 +32,89 @@ class GenotypedContig(Contig):
ref_seq: str
matched_fraction: Optional[float] # Approximated overall concordance between `seq` and `ref_seq`.

@property
def contig(self):
return self

@dataclass
class AlignedContig:
contig: GenotypedContig
alignment: CigarHit

def reference_slice(self, r_st, r_en):
""" Narrows this alignment to a more specific region in reference. """
def cut_reference(self, cut_point: float) -> Tuple['AlignedContig', 'AlignedContig']:
""" Cuts this alignment in two parts with cut_point between them. """

alignment_left, alignment_right = self.alignment.cut_reference(cut_point)
return (AlignedContig(self.contig, alignment_left),
AlignedContig(self.contig, alignment_right))


@cached_property
def msa(self):
return self.alignment.to_msa(self.contig.ref_seq, self.contig.seq)


@cached_property
def seq(self):
seq_left, ref_seq_left = self.msa
return ''.join((c for c in ref_seq_left if c != '-'))


class FrankensteinContig(AlignedContig):
""" Assembled of parts that were not even aligned together,
and of some parts that were not aligned at all.
Yet its .seq string looks like a real contig. """

def __init__(self, parts: List[GenotypedContig]):
self.parts = [subpart for part in parts for subpart in
(part.parts if isinstance(part, FrankensteinContig) else [part])]

name = '+'.join(map(lambda acontig: acontig.contig.name, self.parts))
ref = self.parts[0].contig
contig = GenotypedContig(name=name, seq=self.seq,
ref_name=ref.ref_name,
ref_seq=ref.ref_seq,
matched_fraction=ref.matched_fraction)

alignment = connect_cigar_hits([part.alignment for part in self.parts
if isinstance(part, AlignedContig)])

super().__init__(contig, alignment)


@cached_property
def seq(self):
return ''.join(map(lambda part: part.seq, self.parts))

alignment = self.alignment.reference_slice(r_st, r_en)
contig = self.contig
return AlignedContig(contig, alignment)


def align_to_reference(contig: GenotypedContig):
aligner = Aligner(seq=contig.ref_seq, bw=500, bw_long=500, preset='map-ont')
alignments = list(aligner.map(contig.seq))
if not alignments:
return AlignedContig(contig=contig, alignment=None)
return contig

hits_array = [CigarHit(x.cigar, x.r_st, x.r_en - 1, x.q_st, x.q_en - 1) for x in alignments]
single_cigar_hit = connect_cigar_hits(hits_array)
return AlignedContig(contig=contig, alignment=single_cigar_hit)


def intervals_overlap(x, y):
""" Check if two intervals (x0, x1) and (y0, y1) overlap. """
def align_equal(seq1, seq2) -> Tuple[str, str]:
gap_open_penalty = 15
gap_extend_penalty = 3
use_terminal_gap_penalty = 1
aseq1, aseq2, score = \
align_it(
seq1, seq2,
gap_open_penalty,
gap_extend_penalty,
use_terminal_gap_penalty)

if x[0] > y[1] or y[0] > x[1]:
return False
else:
return True
return aseq1, aseq2


def interval_contains(x, y):
""" Check if interval (x0, x1) contains interval (y0, y1). """

return x[0] <= y[0] and x[1] >= y[1]


Expand All @@ -81,37 +129,60 @@ def find_all_overlapping_contigs(self, aligned_contigs):

def find_overlapping_contig(self, aligned_contigs):
every = find_all_overlapping_contigs(self, aligned_contigs)
return max(chain(every, [None]),
key=lambda other: other.alignment.r_ei - other.alignment.r_st if other else 0)
return max(every, key=lambda other: other.alignment.r_ei - other.alignment.r_st if other else 0,
default=None)


def calculate_concordance(left: str, right: str) -> Iterable[float]:
window_size = 10
scores = deque([0] * window_size, maxlen=window_size)
scores_sum = 0
result = []

def calculate_overlap(left, right):
left_seq_0 = left.contig.seq
right_seq_0 = right.contig.seq
assert len(left) == len(right), "Can only calculate concordance for same sized sequences"

left_mapping = left.alignment.coordinate_mapping
right_mapping = right.alignment.coordinate_mapping
for (a, b) in zip(left, right):
current = a == b
scores_sum -= scores.popleft()
scores_sum += (a == b)
scores.append(current)
result.append(scores_sum / window_size)

overlap_interval = (right.alignment.r_st, left.alignment.r_ei)
# left_overlap_seq = [left.contig.ref_seq[i] for i in ]
return result

left_overlap = left.reference_slice(overlap_interval[0], overlap_interval[1])
right_overlap = right.reference_slice(overlap_interval[0], overlap_interval[1])


def stitch_2_contigs(left, right):
# Cut in 4 parts.
left_remainder, left_overlap = left.cut_reference(right.alignment.r_st - 0.5)
right_overlap, right_remainder = right.cut_reference(left.alignment.r_ei + 0.5)

# Align overlapping parts, then recombine based on concordance.
aligned_left, aligned_right = align_equal(left_overlap.seq, right_overlap.seq)
concordance = calculate_concordance(aligned_left, aligned_right)
max_concordance_index = max(range(len(concordance)),
key=lambda i: concordance[i])
aligned_left_part = aligned_left[:max_concordance_index]
aligned_right_part = aligned_right[max_concordance_index:]
overlap_seq = ''.join(c for c in aligned_left_part + aligned_right_part if c != '-')

# Return something that can be fed back into the loop.
overlap_contig = GenotypedContig(name=f'overlap({left.contig.name},{right.contig.name})',
seq=overlap_seq, ref_name=left.contig.ref_name,
ref_seq=left.contig.ref_seq, matched_fraction=None)
return FrankensteinContig([left_remainder, overlap_contig, right_remainder])


def stitch_contigs(contigs: Iterable[GenotypedContig]):
aligned = list(map(align_to_reference, contigs))

# Contigs that did not align do not need any more processing
stitched = [x.contig for x in aligned if x.alignment is None]
aligned = [x for x in aligned if x.alignment is not None]
stitched = yield from (x for x in aligned if not isinstance(x, AlignedContig))
aligned = [x for x in aligned if isinstance(x, AlignedContig)]

# Going left-to-right through aligned parts.
aligned = list(sorted(aligned, key=lambda x: x.alignment.r_st))
while aligned:
current = aligned.pop(0)
# Going left-to-right through aligned parts.
current = min(aligned, key=lambda x: x.alignment.r_st)
aligned.remove(current)

# Filter out all contigs that are contained within the current one.
# TODO: actually filter out if covered by multiple contigs
Expand All @@ -122,12 +193,10 @@ def stitch_contigs(contigs: Iterable[GenotypedContig]):
# Find overlap. If there isn't one - we are done with the current contig.
overlapping_contig = find_overlapping_contig(current, aligned)
if not overlapping_contig:
stitched.append(current.contig)
yield current
continue

# Get overlaping regions
overlap = calculate_overlap(current, overlapping_contig)

# aligned.append(combined)

return stitched
new_contig = stitch_2_contigs(current, overlapping_contig)
aligned.remove(overlapping_contig)
aligned.append(new_contig)
16 changes: 9 additions & 7 deletions micall/tests/test_contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ def test_1():
),
]

result = stitch_contigs(contigs)
assert result == contigs
result = list(stitch_contigs(contigs))
assert sorted(map(lambda x: x.seq, contigs)) \
== sorted(map(lambda x: x.seq, result))


def test_2():
Expand All @@ -34,9 +35,9 @@ def test_2():
),
]

result = stitch_contigs(contigs)
assert sorted(result, key=lambda x: x.name) \
== sorted(contigs, key=lambda x: x.name)
result = list(stitch_contigs(contigs))
assert sorted(map(lambda x: x.seq, contigs)) \
== sorted(map(lambda x: x.seq, result))


def test_3():
Expand All @@ -57,5 +58,6 @@ def test_3():
),
]

result = stitch_contigs(contigs)
assert False
result = list(stitch_contigs(contigs))
assert 100 == sum(len(x.seq) for x in result)
assert result[0].contig.name == 'a+overlap(a,b)+b'

0 comments on commit 46dbd45

Please sign in to comment.