Skip to content

Commit

Permalink
Refactor and reorganise
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 22, 2024
1 parent fa0f27c commit 717d9de
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions tests/test_API_noncopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
NC = NONCOPY # Sugar


# Helper functions
# TODO: Use the functions in the API instead.
def get_emission_probabilities(m, p_mutation, n_alleles):
# Note that this is different than `set_emission_probabilities` in `api.py`.
# No scaling.
Expand Down Expand Up @@ -60,7 +62,7 @@ def get_example_data(use_multiallelic_sites):
n_alleles = np.zeros(m, dtype=np.int8) + num_alleles
e = get_emission_probabilities(m, p_mutation, n_alleles)

return H, m, n, r, e
return n, m, H, e, r


def get_test_data_biallelic():
Expand Down Expand Up @@ -107,21 +109,6 @@ def get_test_data_biallelic():
]


@pytest.mark.parametrize(
"query, expected_path, expected_num_paths",
get_test_data_biallelic()
)
def test_haploid_viterbi_biallelic(query, expected_path, expected_num_paths):
H, m, n, r, e = get_example_data(use_multiallelic_sites=False)

V, P, _ = vh.forwards_viterbi_hap_naive(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
num_best_paths = np.sum(V[-1, :] == np.max(V[-1, :]))

assert np.array_equal(expected_path, best_path)
assert expected_num_paths == num_best_paths


def get_test_data_multiallelic():
# Crossover from ancestor 3 to ancestor 1.
query_m_a3_x_a1 = np.array([[ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])
Expand All @@ -141,14 +128,30 @@ def get_test_data_multiallelic():
]


# Tests for naive matrix-based implementation.
@pytest.mark.parametrize(
"query, expected_path, expected_num_paths",
get_test_data_biallelic()
)
def test_forwards_viterbi_hap_naive_biallelic(query, expected_path, expected_num_paths):
n, m, H, e, r = get_example_data(use_multiallelic_sites=False)

V, P, _ = vh.forwards_viterbi_hap_naive(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
num_best_paths = np.sum(V[-1, :] == np.max(V[-1, :]))

assert np.array_equal(expected_path, best_path)
assert expected_num_paths == num_best_paths


@pytest.mark.parametrize(
"query, expected_path, expected_num_paths",
get_test_data_multiallelic()
)
def test_haploid_viterbi_multiallelic(
def test_forwards_viterbi_hap_naive_multiallelic(
query, expected_path, expected_num_paths
):
H, m, n, r, e = get_example_data(use_multiallelic_sites=True)
n, m, H, e, r = get_example_data(use_multiallelic_sites=True)

V, P, _ = vh.forwards_viterbi_hap_naive(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
Expand All @@ -163,8 +166,8 @@ def test_haploid_viterbi_multiallelic(
"query, expected_path, expected_num_paths",
get_test_data_biallelic()
)
def test_haploid_viterbi_biallelic(query, expected_path, expected_num_paths):
H, m, n, r, e = get_example_data(use_multiallelic_sites=False)
def test_forwards_viterbi_hap_naive_vec_biallelic(query, expected_path, expected_num_paths):
n, m, H, e, r = get_example_data(use_multiallelic_sites=False)

V, P, _ = vh.forwards_viterbi_hap_naive_vec(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
Expand All @@ -178,10 +181,10 @@ def test_haploid_viterbi_biallelic(query, expected_path, expected_num_paths):
"query, expected_path, expected_num_paths",
get_test_data_multiallelic()
)
def test_haploid_viterbi_multiallelic(
def test_forwards_viterbi_hap_naive_vec_multiallelic(
query, expected_path, expected_num_paths
):
H, m, n, r, e = get_example_data(use_multiallelic_sites=True)
n, m, H, e, r = get_example_data(use_multiallelic_sites=True)

V, P, _ = vh.forwards_viterbi_hap_naive_vec(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
Expand Down

0 comments on commit 717d9de

Please sign in to comment.