From 7813ed0edb1f7a678c765d89201f1bbb4dc1a9a7 Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 1 Jul 2024 18:57:24 +0100 Subject: [PATCH] Update test --- lshmm/fb_haploid.py | 2 -- lshmm/vit_haploid.py | 20 +++++--------------- tests/test_api_vit_haploid_multi.py | 1 + tests/test_nontree_vit_haploid.py | 2 +- 4 files changed, 7 insertions(+), 18 deletions(-) diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py index 331c247..fa70f3b 100644 --- a/lshmm/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -15,7 +15,6 @@ def forwards_ls_hap( e, r, emission_func, - *, norm=True, ): """ @@ -97,7 +96,6 @@ def backwards_ls_hap( e, c, r, - *, emission_func, ): """ diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 6bdfed0..cdcf833 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -14,7 +14,6 @@ def viterbi_naive_init( s, e, r, - *, emission_func, ): """Initialise a naive implementation.""" @@ -43,7 +42,6 @@ def viterbi_init( s, e, r, - *, emission_func, ): """Initialise a naive, but more memory efficient, implementation.""" @@ -73,11 +71,10 @@ def forwards_viterbi_hap_naive( s, e, r, - *, emission_func, ): """A naive implementation of the forward pass.""" - V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) + V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func) for j in range(1, m): for i in range(n): @@ -110,11 +107,10 @@ def forwards_viterbi_hap_naive_vec( s, e, r, - *, emission_func, ): """A naive matrix-based implementation of the forward pass.""" - V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func) + V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func) for j in range(1, m): v_tmp = V[j - 1, :] * r_n[j] @@ -144,11 +140,10 @@ def forwards_viterbi_hap_naive_low_mem( s, e, r, - *, emission_func, ): """A naive implementation of the forward pass with reduced memory.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func) for j in range(1, m): for i in range(n): @@ -182,11 +177,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling( s, e, r, - *, emission_func, ): """A naive implementation of the forward pass with reduced memory and rescaling.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func) c = np.ones(m) for j in range(1, m): @@ -223,11 +217,10 @@ def forwards_viterbi_hap_low_mem_rescaling( s, e, r, - *, emission_func, ): """An implementation with reduced memory that exploits the Markov structure.""" - V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func) + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func) c = np.ones(m) for j in range(1, m): @@ -263,7 +256,6 @@ def forwards_viterbi_hap_lower_mem_rescaling( s, e, r, - *, emission_func, ): """ @@ -317,7 +309,6 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer( s, e, r, - *, emission_func, ): """ @@ -409,7 +400,6 @@ def path_ll_hap( s, e, r, - *, emission_func, ): """ diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index d7dce8a..f2659f3 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -25,6 +25,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): e=e_vs, r=r, emission_func=emission_func, + norm=True, ) path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs) path_ll_hap = vh.path_ll_hap( diff --git a/tests/test_nontree_vit_haploid.py b/tests/test_nontree_vit_haploid.py index 475d121..d681a0c 100644 --- a/tests/test_nontree_vit_haploid.py +++ b/tests/test_nontree_vit_haploid.py @@ -152,7 +152,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): emission_func=emission_func, ) path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + ll_check = vh.path_ll_hap(n=n, m=m, H=H_vs, path=path_tmp, s=s, e=e_vs, r=r, emission_func=emission_func,) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp)