Skip to content

Commit

Permalink
Add tests for calculate_concordance
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Nov 14, 2023
1 parent aaf2a28 commit 98e9240
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
2 changes: 1 addition & 1 deletion micall/core/contig_stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def calculate_concordance(left: str, right: str) -> List[float]:
The function compares the two strings from both left to right and then right to left,
calculating for each position the ratio of matching characters in a window around the
- current position (10 characters to the left and right).
+ current position.
It's required that the input strings are of the same length.
Expand Down
74 changes: 72 additions & 2 deletions micall/tests/test_contig_stitcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

import pytest
from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus
from micall.tests.utils import MockAligner
import random
from micall.core.contig_stitcher import split_contigs_with_gaps, stitch_contigs, GenotypedContig, merge_intervals, find_covered_contig, stitch_consensus, calculate_concordance
from micall.tests.utils import MockAligner, fixed_random_seed


@pytest.fixture()
Expand Down Expand Up @@ -732,3 +733,72 @@ def test_find_covered(contigs, expected_covered_name):
assert covered is not None
assert covered.name == expected_covered_name


def test_concordance_same_length_inputs():
    """Inputs of different lengths must be rejected with a ValueError."""
    longer, shorter = 'abc', 'ab'
    with pytest.raises(ValueError):
        calculate_concordance(longer, shorter)

def test_concordance_completely_different_strings():
    """Two strings that disagree at every position must score zero throughout."""
    left = 'a' * 30
    right = 'b' * 30
    scores = calculate_concordance(left, right)
    # No score may be anything other than zero.
    assert not any(score != 0 for score in scores)

def generate_random_string_pair(length, alphabet='ACGT'):
    """Return two independently drawn random strings of the same length.

    Characters are drawn with ``random.choice`` from the module-level
    ``random`` generator, so the output is reproducible whenever the
    caller has seeded that generator.

    :param length: number of characters in each string; 0 gives two
        empty strings.
    :param alphabet: non-empty sequence of characters to draw from.
        Defaults to 'ACGT'.
    :return: a ``(left, right)`` tuple of strings, each `length` long.
    """

    def random_string():
        return ''.join(random.choice(alphabet) for _ in range(length))

    # The left string is drawn completely before the right one; keeping
    # this order keeps seeded output stable.
    left = random_string()
    right = random_string()
    return left, right

def generate_test_cases(num_cases):
    """Build `num_cases` random string pairs that all share one length.

    The common length is drawn once from 1..80 (inclusive) and then
    reused for every pair.  Everything is generated inside
    ``fixed_random_seed(42)``, presumably so that the cases are the same
    on every run -- confirm against the helper in micall.tests.utils.

    NOTE(review): the listing this was recovered from had lost its
    indentation; this assumes the pairs are generated inside the seeded
    ``with`` block -- confirm.

    :param num_cases: how many (left, right) pairs to generate.
    :return: list of (left, right) string tuples.
    """
    with fixed_random_seed(42):
        shared_length = random.randint(1, 80)
        pairs = []
        for _ in range(num_cases):
            pairs.append(generate_random_string_pair(shared_length))
    return pairs

# Random (left, right) string pairs used to parametrize the concordance
# tests below; generated once, when this module is imported.
concordance_cases = generate_test_cases(num_cases=100)


@pytest.mark.parametrize('left, right', concordance_cases)
def test_concordance_output_is_list_of_floats(left, right):
    """The result must be a list, and every element of it a float."""
    scores = calculate_concordance(left, right)
    assert isinstance(scores, list), "Result should be a list"
    non_floats = [score for score in scores if not isinstance(score, float)]
    assert not non_floats, "All items in result should be float"


@pytest.mark.parametrize('left, right', concordance_cases)
def test_concordance_output_range(left, right):
    """Every concordance score must lie in the closed interval [0, 1]."""
    scores = calculate_concordance(left, right)
    within_bounds = [0 <= score <= 1 for score in scores]
    assert all(within_bounds), "All values in result should be between 0 and 1"


@pytest.mark.parametrize('left, right', concordance_cases)
def test_concordance_higher_if_more_matches_added(left, right):
    """Adding an identical run to both strings must not lower mean concordance.

    A run of 30 'A' characters is written into both strings starting at
    the midpoint of `left`.  It overwrites up to 30 existing characters
    and extends the strings when fewer than 30 remain.
    """
    run = 'A' * 30
    start = len(left) // 2
    end = start + len(run)

    def with_run(sequence):
        return sequence[:start] + run + sequence[end:]

    def mean(values):
        return sum(values) / len(values)

    scores_before = calculate_concordance(left, right)
    scores_after = calculate_concordance(with_run(left), with_run(right))
    assert mean(scores_before) <= mean(scores_after)


@pytest.mark.parametrize('left, right', concordance_cases)
def test_concordance_higher_in_matching_areas(left, right):
    """Mean concordance inside an identical run must exceed the mean outside it.

    A run of 30 'A' characters is written into both strings starting at
    the midpoint of `left`.  It overwrites up to 30 existing characters
    and extends the strings when fewer than 30 remain.  The scores at the
    run's positions are then compared with the scores everywhere else.
    """
    matching_sequence = 'A' * 30
    insert_position = len(left) // 2
    run_end = insert_position + len(matching_sequence)
    new_left = left[:insert_position] + matching_sequence + left[run_end:]
    new_right = right[:insert_position] + matching_sequence + right[run_end:]

    concordance_scores = calculate_concordance(new_left, new_right)

    # Scores at the positions occupied by the identical run.
    matching_area_concordance = concordance_scores[insert_position:run_end]

    # For a very short input (length 1) the run is the entire string, so
    # no position lies outside it and the outside mean is undefined.
    # Skip instead of failing with ZeroDivisionError.
    outside_count = len(concordance_scores) - len(matching_sequence)
    if outside_count == 0:
        pytest.skip("Matching area covers the whole sequence; nothing outside it to compare")

    # Average concordance inside and outside the matching area.
    average_inside = sum(matching_area_concordance) / len(matching_sequence)
    average_outside = (sum(concordance_scores) - sum(matching_area_concordance)) / outside_count

    # The concordance must be strictly higher in the matching area.
    assert average_inside > average_outside, "Concordance in matching areas should be higher than in non-matching areas"

0 comments on commit 98e9240

Please sign in to comment.