Skip to content

Commit

Permalink
Contig stitcher: improve concordance calculations
Browse files Browse the repository at this point in the history
Also add more tests for it.
  • Loading branch information
Donaim committed Jan 23, 2024
1 parent bf1390f commit 4e0440b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
22 changes: 20 additions & 2 deletions micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,29 @@ def slide(start, end):
return result


def disambiguate_concordance(concordance: List[float]) -> List[Tuple[float, int]]:
def slide(concordance):
count = 0
for i, (prev, current, next) in enumerate(sliding_window(concordance)):
if current == prev:
count += 1
yield count
else:
yield 0

forward = list(slide(concordance))
reverse = list(reversed(list(slide(reversed(concordance)))))
for i, (x, f, r) in enumerate(zip(concordance, forward, reverse)):
local_rank = f * r
global_rank = i if i < len(concordance) / 2 else len(concordance) - i - 1
yield (x, local_rank, global_rank)


def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance):
""" Determine optimal cut points for stitching based on sequence concordance in the overlap region. """

valuator = lambda i: (concordance[i], i if i < len(concordance) / 2 else len(concordance) - i - 1)
sorted_concordance_indexes = sorted(range(len(concordance)), key=valuator)
concordance_d = list(disambiguate_concordance(concordance))
sorted_concordance_indexes = sorted(range(len(concordance)), key=lambda i: concordance_d[i])
remove_dashes = lambda s: ''.join(c for c in s if c != '-')

for max_concordance_index in reversed(sorted_concordance_indexes):
Expand Down
38 changes: 37 additions & 1 deletion micall/tests/test_contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import pytest

from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig
from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig, disambiguate_concordance
from micall.core.plot_contigs import plot_stitcher_coverage
from micall.tests.utils import MockAligner, fixed_random_seed
from micall.utils.structured_logger import add_structured_handler
Expand Down Expand Up @@ -974,6 +974,42 @@ def generate_random_string_pair(length):
right = ''.join(random.choice('ACGT') for _ in range(length))
return left, right


@pytest.mark.parametrize(
'left, right, expected',
[("aaaaa", "aaaaa", [0.1] * 5),
("abcdd", "abcdd", [0.1] * 5),
("aaaaaaaa", "baaaaaab", [0.1, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.1]),
("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
("aaaaaaaa", "aaaaabbb", [0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.08, 0.08]),
("aaaaaaaa", "aaabbaaa", [0.12, 0.12, 0.12, 0.1, 0.1, 0.12, 0.12, 0.12]),
("aaaaa", "bbbbb", [0] * 5),
]
)
def test_concordance_simple(left, right, expected):
result = [round(x, 2) for x in calculate_concordance(left, right)]
assert result == expected


@pytest.mark.parametrize(
'left, right, expected',
[("a" * 128, "a" * 128, 64),
("a" * 128, "a" * 64 + "b" * 64, 32),
("a" * 128, "a" * 64 + "ba" * 32, 32),
("a" * 128, "a" * 54 + "b" * 20 + "a" * 54, 28), # two peaks
("a" * 128, "a" * 63 + "b" * 2 + "a" * 63, 32), # two peaks
("a" * 1280, "b" * 640 + "a" * 640, 640 + 30), # the window is too small to account for all of the context
]
)
def test_concordance_simple_index(left, right, expected):
concordance = calculate_concordance(left, right)
concordance_d = list(disambiguate_concordance(concordance))
index = max(range(len(concordance)), key=lambda i: concordance_d[i])
if abs(index - expected) > 3:
assert index == expected


def generate_test_cases(num_cases):
with fixed_random_seed(42):
length = random.randint(1, 80)
Expand Down

0 comments on commit 4e0440b

Please sign in to comment.