diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py index 89cab51..8f7b311 100644 --- a/lshmm/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -8,7 +8,7 @@ @jit.numba_njit def forwards_ls_hap( - n, m, H, s, e, r, norm=True, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, emission_func, *, norm=True, ): """ A matrix-based implementation using Numpy. @@ -82,7 +82,7 @@ def forwards_ls_hap( @jit.numba_njit def backwards_ls_hap( - n, m, H, s, e, c, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, c, r, *, emission_func, ): """ A matrix-based implementation using Numpy. diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index f608edb..b772324 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -8,7 +8,7 @@ @jit.numba_njit def viterbi_naive_init( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """Initialise a naive implementation.""" V = np.zeros((m, n)) @@ -30,7 +30,7 @@ def viterbi_naive_init( @jit.numba_njit def viterbi_init( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """Initialise a naive, but more memory efficient, implementation.""" V_prev = np.zeros(n) @@ -53,7 +53,7 @@ def viterbi_init( @jit.numba_njit def forwards_viterbi_hap_naive( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """A naive implementation of the forward pass.""" V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) @@ -83,7 +83,7 @@ def forwards_viterbi_hap_naive( @jit.numba_njit def forwards_viterbi_hap_naive_vec( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """A naive matrix-based implementation of the forward pass.""" V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) @@ -110,7 +110,7 @@ def forwards_viterbi_hap_naive_vec( @jit.numba_njit def forwards_viterbi_hap_naive_low_mem( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """A naive implementation of the forward pass with reduced memory.""" V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) @@ -141,7 +141,7 @@ def forwards_viterbi_hap_naive_low_mem( @jit.numba_njit def forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """A naive implementation of the forward pass with reduced memory and rescaling.""" V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) @@ -175,7 +175,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling( @jit.numba_njit def forwards_viterbi_hap_low_mem_rescaling( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """An implementation with reduced memory that exploits the Markov structure.""" V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) @@ -208,7 +208,7 @@ def forwards_viterbi_hap_low_mem_rescaling( @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """ An implementation with even smaller memory footprint @@ -255,7 +255,7 @@ def forwards_viterbi_hap_lower_mem_rescaling( @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, s, e, r, *, emission_func, ): """ An implementation with even smaller memory footprint and rescaling @@ -339,7 +339,7 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): @jit.numba_njit def path_ll_hap( - n, m, H, path, s, e, r, *, emission_func=core.get_emission_probability_haploid + n, m, H, path, s, e, r, *, emission_func, ): """ Evaluate the log-likelihood of a path through a reference panel resulting in a query. diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py index 7238283..1cf9656 100644 --- a/tests/test_api_fb_haploid.py +++ b/tests/test_api_fb_haploid.py @@ -16,7 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) + emission_func = core.get_emission_probability_haploid F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -24,6 +24,8 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, + norm=True, ) B_vs = fbh.backwards_ls_hap( n=n, @@ -33,6 +35,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): e=e_vs, c=c_vs, r=r, + emission_func=emission_func, ) F, c, ll = ls.forwards( reference_panel=H_vs, diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index a90f57a..77b1b6d 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): + emission_func = core.get_emission_probability_haploid F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -23,6 +24,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, ) B_vs = fbh.backwards_ls_hap( n=n, @@ -32,6 +34,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): e=e_vs, c=c_vs, r=r, + emission_func=emission_func, ) F, c, ll = ls.forwards( reference_panel=H_vs, diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index 5020171..d7dce8a 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): include_ancestors=include_ancestors, include_extreme_rates=True, ): + emission_func = core.get_emission_probability_haploid V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n=n, m=m, @@ -23,9 +24,19 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): s=s, e=e_vs, r=r, + emission_func=emission_func, ) path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs) - path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + path_ll_hap = vh.path_ll_hap( + n=n, + m=m, + H=H_vs, + path=path_vs, + s=s, + e=e_vs, + r=r, + emission_func=emission_func, + ) path, ll = ls.viterbi( reference_panel=H_vs, query=s, @@ -44,11 +55,7 @@ def test_ts_multiallelic_n10_no_recomb( self, scale_mutation_rate, include_ancestors ): ts = self.get_ts_multiallelic_n10_no_recomb() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) + self.verify(ts, scale_mutation_rate, include_ancestors) @pytest.mark.parametrize("num_samples", [6, 8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) diff --git a/tests/test_nontree_fb_haploid.py b/tests/test_nontree_fb_haploid.py index 7f76a0f..c8f277e 100644 --- a/tests/test_nontree_fb_haploid.py +++ b/tests/test_nontree_fb_haploid.py @@ -10,20 +10,22 @@ class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + emission_func = core.get_emission_probability_haploid + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, emission_func, norm=False) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r, emission_func) self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True + n, m, H_vs, s, e_vs, r, emission_func, norm=True, ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) + B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r, emission_func) self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) self.assertAllClose(ll_vs, ll_tmp) diff --git a/tests/test_nontree_vit_haploid.py b/tests/test_nontree_vit_haploid.py index 93bd7a5..1b61887 100644 --- a/tests/test_nontree_vit_haploid.py +++ b/tests/test_nontree_vit_haploid.py @@ -10,9 +10,10 @@ class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True,