From b6c9c71be3f6b96c415bb95088626fbe9bb5c853 Mon Sep 17 00:00:00 2001 From: szhan Date: Sat, 20 Apr 2024 17:10:13 +0100 Subject: [PATCH] Refactor --- lshmm/core.py | 4 +- lshmm/fb_haploid.py | 15 +- tests/lsbase.py | 4 +- tests/test_non_tree.py | 458 ++++++++++++++++++----------------------- 4 files changed, 213 insertions(+), 268 deletions(-) diff --git a/lshmm/core.py b/lshmm/core.py index 30f86d7..fc5b817 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -55,8 +55,8 @@ def np_argmax(array, axis): @jit.numba_njit -def get_index_in_emission_matrix(ref_allele, query_allele): - is_allele_match = np.equal(ref_allele, query_allele) +def get_index_in_emission_matrix_haploid(ref_allele, query_allele): + is_allele_match = ref_allele == query_allele is_query_missing = query_allele == MISSING if is_allele_match or is_query_missing: return 1 diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py index 2541e02..28d67f5 100644 --- a/lshmm/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -11,14 +11,14 @@ @jit.numba_njit def forwards_ls_hap(n, m, H, s, e, r, norm=True): - """A matrix-based implementation using Numpy vectorisation.""" + """A matrix-based implementation using Numpy.""" F = np.zeros((m, n)) r_n = r / n if norm: c = np.zeros(m) for i in range(n): - emission_index = core.get_index_in_emission_matrix( + emission_index = core.get_index_in_emission_matrix_haploid( ref_allele=H[0, i], query_allele=s[0, 0] ) F[0, i] = 1 / n * e[0, emission_index] @@ -31,7 +31,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - emission_index = core.get_index_in_emission_matrix( + emission_index = core.get_index_in_emission_matrix_haploid( ref_allele=H[l, i], query_allele=s[0, l] ) F[l, i] *= e[l, emission_index] @@ -44,9 +44,8 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): else: c = np.ones(m) - for i in range(n): - emission_index = core.get_index_in_emission_matrix( + emission_index = core.get_index_in_emission_matrix_haploid( ref_allele=H[0, i], query_allele=s[0, 0] ) F[0, i] = 1 / n * e[0, emission_index] @@ -55,7 +54,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - emission_index = core.get_index_in_emission_matrix( + emission_index = core.get_index_in_emission_matrix_haploid( ref_allele=H[l, i], query_allele=s[0, l] ) F[l, i] *= e[l, emission_index] @@ -67,7 +66,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit def backwards_ls_hap(n, m, H, s, e, c, r): - """A matrix-based implementation using Numpy vectorisation.""" + """A matrix-based implementation using Numpy.""" B = np.zeros((m, n)) for i in range(n): B[m - 1, i] = 1 @@ -78,7 +77,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - emission_index = core.get_index_in_emission_matrix( + emission_index = core.get_index_in_emission_matrix_haploid( ref_allele=H[l + 1, i], query_allele=s[0, l + 1] ) tmp_B[i] = e[l + 1, emission_index] * B[l + 1, i] diff --git a/tests/lsbase.py b/tests/lsbase.py index 0922f81..844309c 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -285,10 +285,10 @@ def get_multiallelic_n16(self, seed=42): return ts # Prepare a larger example dataset. - def get_larger(self, num_samples, seq_length, mean_r, mean_mu, seed=42): + def get_larger(self, num_samples, length, mean_r, mean_mu, seed=42): ts = msprime.simulate( num_samples + 1, - length=seq_length, + length=length, mutation_rate=mean_mu, recombination_rate=mean_r, random_seed=seed, diff --git a/tests/test_non_tree.py b/tests/test_non_tree.py index fa4c975..eb062e3 100644 --- a/tests/test_non_tree.py +++ b/tests/test_non_tree.py @@ -9,354 +9,300 @@ class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): - def verify(self, n, m, H_vs, s, e_vs, r): - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) + def verify(self, ts, mean_r=None, mean_mu=None): + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, mean_r=mean_r, mean_mu=mean_mu + ): + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=True) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) + F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( + n, m, H_vs, s, e_vs, r, norm=True + ) + B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) + self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) + self.assertAllClose(ll_vs, ll_tmp) def test_simple_n10_no_recombination(self): ts = self.get_simple_n10_no_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n6(self): ts = self.get_simple_n6() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8(self): ts = self.get_simple_n8() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8_high_recombination(self): ts = self.get_simple_n8_high_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n16(self): ts = self.get_simple_n16() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_larger(self): - seed = 42 - num_samples = 50 - seq_length = 1e5 mean_r = 1e-5 mean_mu = 1e-5 - ts = self.get_larger( - num_samples, - seq_length, - mean_r, - mean_mu, - seed=seed, - ) - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts = self.get_larger(50, length=1e5, mean_r=mean_r, mean_mu=mean_mu) + self.verify(ts, mean_r=mean_r, mean_mu=mean_mu) + + +class TestNonTreeForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): + def verify(self, ts, mean_r=None, mean_mu=None): + for n, m, G_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( ts, mean_r=mean_r, mean_mu=mean_mu, - seed=seed, ): - self.verify(n, m, H_vs, s, e_vs, r) + F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) + B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - -class TestNonTreeForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): - def verify(self, n, m, G_vs, s, e_vs, r): - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - - F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=False) - if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( + n, m, G_vs, s, e_vs, r, norm=False ) - - F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + if ll_tmp != -np.inf: + B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + self.assertAllClose(ll_vs, ll_tmp) + + F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + self.assertAllClose(ll_vs, ll_tmp) + + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=False + ) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + self.assertAllClose(ll_vs, ll_tmp) + + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=True ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True - ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) def test_simple_n10_no_recombination(self): ts = self.get_simple_n10_no_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n6(self): ts = self.get_simple_n6() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8(self): ts = self.get_simple_n8() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8_high_recombination(self): ts = self.get_simple_n8_high_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n16(self): ts = self.get_simple_n16() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_larger(self): - seed = 42 - num_samples = 50 - seq_length = 1e5 mean_r = 1e-5 mean_mu = 1e-5 - ts = self.get_larger( - num_samples, - seq_length, - mean_r, - mean_mu, - seed=seed, - ) - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( - ts, - mean_r=mean_r, - mean_mu=mean_mu, - seed=seed, - ): - self.verify(n, m, H_vs, s, e_vs, r) + ts = self.get_larger(50, length=1e5, mean_r=mean_r, mean_mu=mean_mu) + self.verify(ts, mean_r=mean_r, mean_mu=mean_mu) class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase): - def verify(self, n, m, H_vs, s, e_vs, r): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_check) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(n, m, H_vs, s, e_vs, r) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - recombs, - ll_tmp, - ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, - V_argmaxes_tmp, - nb.typed.List(recombs), - ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) + def verify(self, ts, mean_r=None, mean_mu=None): + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, mean_r=mean_r, mean_mu=mean_mu + ): + V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) + path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) + ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, ll_check) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + ( + V_tmp, + V_argmaxes_tmp, + recombs, + ll_tmp, + ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap_no_pointer( + m, + V_argmaxes_tmp, + nb.typed.List(recombs), + ) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) def test_simple_n10_no_recombination(self): ts = self.get_simple_n10_no_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n6(self): ts = self.get_simple_n6() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8(self): ts = self.get_simple_n8() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8_high_recombination(self): ts = self.get_simple_n8_high_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n16(self): ts = self.get_simple_n16() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_larger(self): - seed = 42 - num_samples = 50 - seq_length = 1e5 mean_r = 1e-5 mean_mu = 1e-5 - ts = self.get_larger( - num_samples, - seq_length, - mean_r, - mean_mu, - seed=seed, - ) - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts = self.get_larger(50, length=1e5, mean_r=mean_r, mean_mu=mean_mu) + self.verify(ts, mean_r=mean_r, mean_mu=mean_mu) + + +class TestNonTreeViterbiDiploid(lsbase.ViterbiAlgorithmBase): + def verify(self, ts, mean_r=None, mean_mu=None): + for n, m, G_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( ts, mean_r=mean_r, mean_mu=mean_mu, - seed=seed, ): - self.verify(n, m, H_vs, s, e_vs, r) + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) + path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) + phased_path_vs = vd.get_phased_path(n, path_vs) + path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, path_ll_vs) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) -class TestNonTreeViterbiDiploid(lsbase.ViterbiAlgorithmBase): - def verify(self, n, m, G_vs, s, e_vs, r): - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) - phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, path_ll_vs) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - recombs_single, - recombs_double, - ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip_no_pointer( - m, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - nb.typed.List(recombs_single), - nb.typed.List(recombs_double), - V_tmp, - ) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + ( + V_tmp, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + recombs_single, + recombs_double, + ll_tmp, + ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) + path_tmp = vd.backwards_viterbi_dip_no_pointer( + m, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + nb.typed.List(recombs_single), + nb.typed.List(recombs_double), + V_tmp, + ) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) def test_simple_n10_no_recombination(self): ts = self.get_simple_n10_no_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n6(self): ts = self.get_simple_n6() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8(self): ts = self.get_simple_n8() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n8_high_recombination(self): ts = self.get_simple_n8_high_recombination() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_simple_n16(self): ts = self.get_simple_n16() - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): - self.verify(n, m, H_vs, s, e_vs, r) + self.verify(ts) def test_larger(self): - seed = 42 - num_samples = 50 - seq_length = 1e5 mean_r = 1e-5 mean_mu = 1e-5 - ts = self.get_larger( - num_samples, - seq_length, - mean_r, - mean_mu, - seed=seed, - ) - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( - ts, - mean_r=mean_r, - mean_mu=mean_mu, - seed=seed, - ): - self.verify(n, m, H_vs, s, e_vs, r) + ts = self.get_larger(50, length=1e5, mean_r=mean_r, mean_mu=mean_mu) + self.verify(ts, mean_r=mean_r, mean_mu=mean_mu)