From 4e0440b182da1b0c2a50e2ac0ec33b426890f5c5 Mon Sep 17 00:00:00 2001
From: Vitaliy Mysak
Date: Tue, 23 Jan 2024 14:37:37 -0800
Subject: [PATCH] Contig stitcher: improve concordance calculations

Also add more tests for it.
---
 micall/core/contig_stitcher.py       | 37 ++++++++++++++++++++++++++--
 micall/tests/test_contig_stitcher.py | 37 ++++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 3 deletions(-)

diff --git a/micall/core/contig_stitcher.py b/micall/core/contig_stitcher.py
index 9d15720a9..9fd93457d 100644
--- a/micall/core/contig_stitcher.py
+++ b/micall/core/contig_stitcher.py
@@ -384,11 +384,44 @@ def slide(start, end):
     return result
 
 
+def disambiguate_concordance(concordance: List[float]) -> List[Tuple[float, int, int]]:
+    """
+    Attach tie-breaking ranks to every concordance value.
+
+    Position `i` becomes `(value, local_rank, global_rank)`, so that
+    comparing the tuples orders equal values by the two ranks:
+
+    - `local_rank` multiplies the counters that `slide` produces when the
+      values are scanned forwards and backwards.
+    - `global_rank` is the distance from `i` to the nearest end of the list.
+
+    :param concordance: the concordance values to rank.
+    :return: the triples, in the order of `concordance`.
+    """
+
+    def slide(values):
+        # NOTE(review): `count` is never reset when neighbours differ, so it
+        # keeps growing across separate runs -- confirm that this is intended.
+        count = 0
+        for prev, current, _ in sliding_window(values):
+            if current == prev:
+                count += 1
+                yield count
+            else:
+                yield 0
+
+    forward = list(slide(concordance))
+    reverse = list(reversed(list(slide(reversed(concordance)))))
+    last = len(concordance) - 1
+    return [(x, f * r, min(i, last - i))
+            for i, (x, f, r) in enumerate(zip(concordance, forward, reverse))]
+
+
 def concordance_to_cut_points(left_overlap, right_overlap, aligned_left, aligned_right, concordance):
     """ Determine optimal cut points for stitching based on sequence concordance in the overlap region. """
 
-    valuator = lambda i: (concordance[i], i if i < len(concordance) / 2 else len(concordance) - i - 1)
-    sorted_concordance_indexes = sorted(range(len(concordance)), key=valuator)
+    concordance_d = list(disambiguate_concordance(concordance))
+    sorted_concordance_indexes = sorted(range(len(concordance)), key=lambda i: concordance_d[i])
 
     remove_dashes = lambda s: ''.join(c for c in s if c != '-')
     for max_concordance_index in reversed(sorted_concordance_indexes):
diff --git a/micall/tests/test_contig_stitcher.py b/micall/tests/test_contig_stitcher.py
index f6193688f..3fff183fb 100644
--- a/micall/tests/test_contig_stitcher.py
+++ b/micall/tests/test_contig_stitcher.py
@@ -4,7 +4,7 @@ import os
 
 import pytest
 
-from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig
+from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance, align_all_to_reference, main, AlignedContig, disambiguate_concordance
 from micall.core.plot_contigs import plot_stitcher_coverage
 from micall.tests.utils import MockAligner, fixed_random_seed
 from micall.utils.structured_logger import add_structured_handler
@@ -974,6 +974,41 @@ def generate_random_string_pair(length):
     right = ''.join(random.choice('ACGT') for _ in range(length))
     return left, right
 
+
+@pytest.mark.parametrize(
+    'left, right, expected',
+    [("aaaaa", "aaaaa", [0.1] * 5),
+     ("abcdd", "abcdd", [0.1] * 5),
+     ("aaaaaaaa", "baaaaaab", [0.1, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.1]),
+     ("aaaaaaaa", "aaaaaaab", [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12]),
+     ("aaaaaaaa", "aaaaabbb", [0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.08, 0.08]),
+     ("aaaaaaaa", "aaabbaaa", [0.12, 0.12, 0.12, 0.1, 0.1, 0.12, 0.12, 0.12]),
+     ("aaaaa", "bbbbb", [0] * 5),
+     ]
+)
+def test_concordance_simple(left, right, expected):
+    result = [round(x, 2) for x in calculate_concordance(left, right)]
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    'left, right, expected',
+    [("a" * 128, "a" * 128, 64),
+     ("a" * 128, "a" * 64 + "b" * 64, 32),
+     ("a" * 128, "a" * 64 + "ba" * 32, 32),
+     ("a" * 128, "a" * 54 + "b" * 20 + "a" * 54, 28),  # two peaks
+     ("a" * 128, "a" * 63 + "b" * 2 + "a" * 63, 32),  # two peaks
+     ("a" * 1280, "b" * 640 + "a" * 640, 640 + 30),  # the window is too small to account for all of the context
+     ]
+)
+def test_concordance_simple_index(left, right, expected):
+    concordance = calculate_concordance(left, right)
+    concordance_d = list(disambiguate_concordance(concordance))
+    index = max(range(len(concordance)), key=lambda i: concordance_d[i])
+    # The best index may land up to three positions away from the expected one.
+    assert abs(index - expected) <= 3, (index, expected)
+
+
 def generate_test_cases(num_cases):
     with fixed_random_seed(42):
         length = random.randint(1, 80)