Skip to content

Commit

Permalink
Add tests for naive vectorised implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 22, 2024
1 parent 444491e commit fa0f27c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
for i in range(n):
v = np.copy(v_tmp)
v[i] += V[j - 1, i] * (1 - r[j])
v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != -2:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v *= em_prob
P[j, i] = np.argmax(v)
V[j, i] = v[P[j, i]]

Expand Down
33 changes: 33 additions & 0 deletions tests/test_API_noncopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,36 @@ def test_haploid_viterbi_multiallelic(

assert np.array_equal(expected_path, best_path)
assert expected_num_paths == num_best_paths


# Tests for naive matrix-based implementation using numpy.
@pytest.mark.parametrize(
"query, expected_path, expected_num_paths",
get_test_data_biallelic()
)
def test_haploid_viterbi_biallelic(query, expected_path, expected_num_paths):
H, m, n, r, e = get_example_data(use_multiallelic_sites=False)

V, P, _ = vh.forwards_viterbi_hap_naive_vec(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
num_best_paths = np.sum(V[-1, :] == np.max(V[-1, :]))

assert np.array_equal(expected_path, best_path)
assert expected_num_paths == num_best_paths


@pytest.mark.parametrize(
"query, expected_path, expected_num_paths",
get_test_data_multiallelic()
)
def test_haploid_viterbi_multiallelic(
query, expected_path, expected_num_paths
):
H, m, n, r, e = get_example_data(use_multiallelic_sites=True)

V, P, _ = vh.forwards_viterbi_hap_naive_vec(n, m, H, query, e, r)
best_path = get_viterbi_path(V, P)
num_best_paths = np.sum(V[-1, :] == np.max(V[-1, :]))

assert np.array_equal(expected_path, best_path)
assert expected_num_paths == num_best_paths

0 comments on commit fa0f27c

Please sign in to comment.