Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent bf89ac0 commit 7813ed0
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 18 deletions.
2 changes: 0 additions & 2 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def forwards_ls_hap(
e,
r,
emission_func,
*,
norm=True,
):
"""
Expand Down Expand Up @@ -97,7 +96,6 @@ def backwards_ls_hap(
e,
c,
r,
*,
emission_func,
):
"""
Expand Down
20 changes: 5 additions & 15 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def viterbi_naive_init(
s,
e,
r,
*,
emission_func,
):
"""Initialise a naive implementation."""
Expand Down Expand Up @@ -43,7 +42,6 @@ def viterbi_init(
s,
e,
r,
*,
emission_func,
):
"""Initialise a naive, but more memory efficient, implementation."""
Expand Down Expand Up @@ -73,11 +71,10 @@ def forwards_viterbi_hap_naive(
s,
e,
r,
*,
emission_func,
):
"""A naive implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func)

for j in range(1, m):
for i in range(n):
Expand Down Expand Up @@ -110,11 +107,10 @@ def forwards_viterbi_hap_naive_vec(
s,
e,
r,
*,
emission_func,
):
"""A naive matrix-based implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func)

for j in range(1, m):
v_tmp = V[j - 1, :] * r_n[j]
Expand Down Expand Up @@ -144,11 +140,10 @@ def forwards_viterbi_hap_naive_low_mem(
s,
e,
r,
*,
emission_func,
):
"""A naive implementation of the forward pass with reduced memory."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)

for j in range(1, m):
for i in range(n):
Expand Down Expand Up @@ -182,11 +177,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(
s,
e,
r,
*,
emission_func,
):
"""A naive implementation of the forward pass with reduced memory and rescaling."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)
c = np.ones(m)

for j in range(1, m):
Expand Down Expand Up @@ -223,11 +217,10 @@ def forwards_viterbi_hap_low_mem_rescaling(
s,
e,
r,
*,
emission_func,
):
"""An implementation with reduced memory that exploits the Markov structure."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)
c = np.ones(m)

for j in range(1, m):
Expand Down Expand Up @@ -263,7 +256,6 @@ def forwards_viterbi_hap_lower_mem_rescaling(
s,
e,
r,
*,
emission_func,
):
"""
Expand Down Expand Up @@ -317,7 +309,6 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
s,
e,
r,
*,
emission_func,
):
"""
Expand Down Expand Up @@ -409,7 +400,6 @@ def path_ll_hap(
s,
e,
r,
*,
emission_func,
):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_api_vit_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
e=e_vs,
r=r,
emission_func=emission_func,
norm=True,
)
path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs)
path_ll_hap = vh.path_ll_hap(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nontree_vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
ll_check = vh.path_ll_hap(n=n, m=m, H=H_vs, path=path_tmp, s=s, e=e_vs, r=r, emission_func=emission_func,)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand Down

0 comments on commit 7813ed0

Please sign in to comment.