Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent 692b477 commit 7a74c7f
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 25 deletions.
4 changes: 2 additions & 2 deletions lshmm/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@jit.numba_njit
def forwards_ls_hap(
n, m, H, s, e, r, norm=True, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, emission_func, *, norm=True,
):
"""
A matrix-based implementation using Numpy.
Expand Down Expand Up @@ -82,7 +82,7 @@ def forwards_ls_hap(

@jit.numba_njit
def backwards_ls_hap(
n, m, H, s, e, c, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, c, r, *, emission_func,
):
"""
A matrix-based implementation using Numpy.
Expand Down
20 changes: 10 additions & 10 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@jit.numba_njit
def viterbi_naive_init(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""Initialise a naive implementation."""
V = np.zeros((m, n))
Expand All @@ -30,7 +30,7 @@ def viterbi_naive_init(

@jit.numba_njit
def viterbi_init(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""Initialise a naive, but more memory efficient, implementation."""
V_prev = np.zeros(n)
Expand All @@ -53,7 +53,7 @@ def viterbi_init(

@jit.numba_njit
def forwards_viterbi_hap_naive(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""A naive implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)
Expand Down Expand Up @@ -83,7 +83,7 @@ def forwards_viterbi_hap_naive(

@jit.numba_njit
def forwards_viterbi_hap_naive_vec(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""A naive matrix-based implementation of the forward pass."""
V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func)
Expand All @@ -110,7 +110,7 @@ def forwards_viterbi_hap_naive_vec(

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""A naive implementation of the forward pass with reduced memory."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
Expand Down Expand Up @@ -141,7 +141,7 @@ def forwards_viterbi_hap_naive_low_mem(

@jit.numba_njit
def forwards_viterbi_hap_naive_low_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""A naive implementation of the forward pass with reduced memory and rescaling."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
Expand Down Expand Up @@ -175,7 +175,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(

@jit.numba_njit
def forwards_viterbi_hap_low_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""An implementation with reduced memory that exploits the Markov structure."""
V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func)
Expand Down Expand Up @@ -208,7 +208,7 @@ def forwards_viterbi_hap_low_mem_rescaling(

@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""
An implementation with even smaller memory footprint
Expand Down Expand Up @@ -255,7 +255,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(

@jit.numba_njit
def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, s, e, r, *, emission_func,
):
"""
An implementation with even smaller memory footprint and rescaling
Expand Down Expand Up @@ -339,7 +339,7 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs):

@jit.numba_njit
def path_ll_hap(
n, m, H, path, s, e, r, *, emission_func=core.get_emission_probability_haploid
n, m, H, path, s, e, r, *, emission_func,
):
"""
Evaluate the log-likelihood of a path through a reference panel resulting in a query.
Expand Down
5 changes: 4 additions & 1 deletion tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
num_alleles = core.get_num_alleles(H_vs, s)
emission_func = core.get_emission_probability_haploid
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
norm=True,
)
B_vs = fbh.backwards_ls_hap(
n=n,
Expand All @@ -33,6 +35,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
e=e_vs,
c=c_vs,
r=r,
emission_func=emission_func,
)
F, c, ll = ls.forwards(
reference_panel=H_vs,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_api_fb_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
emission_func = core.get_emission_probability_haploid
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
B_vs = fbh.backwards_ls_hap(
n=n,
Expand All @@ -32,6 +34,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
e=e_vs,
c=c_vs,
r=r,
emission_func=emission_func,
)
F, c, ll = ls.forwards(
reference_panel=H_vs,
Expand Down
19 changes: 13 additions & 6 deletions tests/test_api_vit_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,27 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
emission_func = core.get_emission_probability_haploid
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs)
path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
path_ll_hap = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path, ll = ls.viterbi(
reference_panel=H_vs,
query=s,
Expand All @@ -44,11 +55,7 @@ def test_ts_multiallelic_n10_no_recomb(
self, scale_mutation_rate, include_ancestors
):
ts = self.get_ts_multiallelic_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [6, 8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
Expand Down
12 changes: 7 additions & 5 deletions tests/test_nontree_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@

class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
ts,
ploidy=1,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
emission_func = core.get_emission_probability_haploid
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, emission_func, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r, emission_func)
self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m))
F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap(
n, m, H_vs, s, e_vs, r, norm=True
n, m, H_vs, s, e_vs, r, emission_func, norm=True,
)
B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r)
B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r, emission_func)
self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_nontree_vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
ts,
ploidy=1,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
Expand Down

0 comments on commit 7a74c7f

Please sign in to comment.