Skip to content

Commit

Permalink
Reference using argument names
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent 7a74c7f commit 01736b2
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 26 deletions.
42 changes: 38 additions & 4 deletions tests/test_nontree_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,47 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_extreme_rates=True,
):
emission_func = core.get_emission_probability_haploid
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, emission_func, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r, emission_func)
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
norm=False,
)
B_vs = fbh.backwards_ls_hap(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
c=c_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m))
F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap(
n, m, H_vs, s, e_vs, r, emission_func, norm=True,
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
norm=True,
)
B_tmp = fbh.backwards_ls_hap(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
c=c_tmp,
r=r,
emission_func=emission_func,
)
B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r, emission_func)
self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

Expand Down
152 changes: 130 additions & 22 deletions tests/test_nontree_vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,140 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r)
path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs)
ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
emission_func = core.get_emission_probability_haploid

V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_vs = vh.backwards_viterbi_hap(
m=m,
V_last=V_vs[m - 1, :],
P=P_vs,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_vs, ll_check)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(
m=m,
V_last=V_tmp[m - 1, :],
P=P_tmp,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)
Expand All @@ -69,14 +162,29 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
recombs,
ll_tmp,
) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n, m, H_vs, s, e_vs, r
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m,
V_argmaxes_tmp,
nb.typed.List(recombs),
m=m,
V_argmaxes=V_argmaxes_tmp,
recombs=nb.typed.List(recombs),
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand Down

0 comments on commit 01736b2

Please sign in to comment.