From 5789876251421d3a8d82fa1b653ef11ce6154dca Mon Sep 17 00:00:00 2001 From: szhan Date: Sat, 20 Apr 2024 14:22:15 +0100 Subject: [PATCH] Major refactor --- tests/lsbase.py | 113 ++--- tests/test_LS_haploid_diploid.py | 717 ++++++++++++++----------------- 2 files changed, 392 insertions(+), 438 deletions(-) diff --git a/tests/lsbase.py b/tests/lsbase.py index 3f385c3..fbeca7f 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -13,9 +13,6 @@ class LSBase: def verify(self, ts): raise NotImplementedError() - def verify_larger(self, ts): - pass - def assertAllClose(self, A, B): np.testing.assert_allclose(A, B, rtol=1e-9, atol=0.0) @@ -70,24 +67,39 @@ def get_emission_matrix_haploid( e[j, 1] = 1 - mu[j] return e - def get_examples_pars_haploid(self, ts, scale_mutation=True, seed=42): + def get_examples_pars_haploid( + self, + ts, + mean_r=None, + mean_mu=None, + scale_mutation=True, + seed=42 + ): """Returns an iterator over combinations of examples and parameters.""" np.random.seed(seed) H, haplotypes = self.get_examples_haploid(ts) m = ts.num_sites n = H.shape[1] - rs = [ - np.zeros(m) + 0.01, # Equal recombination and mutation - np.zeros(m) + 0.999, # Extreme - np.zeros(m) + 1e-6, # Extreme - np.random.rand(m), # Random - ] - mus = [ - np.zeros(m) + 0.01, # Equal recombination and mutation - np.zeros(m) + 0.2, # Extreme - np.zeros(m) + 1e-6, # Extreme - np.random.rand(m) * 0.2, # Random - ] + if mean_r is not None and mean_mu is not None: + rs = [ + mean_r * (np.random.rand(m) + 0.5) / 2 + ] + mus = [ + mean_mu * (np.random.rand(m) + 0.5) / 2 + ] + else: + rs = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.999, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m), # Random + ] + mus = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.2, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m) * 0.2, # Random + ] for s, r, mu in itertools.product(haplotypes, rs, mus): r[0] = 0 # Must be calculated from the genotype matrix, @@ -101,8 +113,7 @@ def get_examples_pars_haploid(self, ts, scale_mutation=True, seed=42): yield n, m, H, s, e, r, mu # Diploid - def get_examples_diploid(self, ts, seed=42): - np.random.seed(seed) + def get_examples_diploid(self, ts): H = ts.genotype_matrix() s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) H = H[:, 2:] @@ -116,9 +127,10 @@ def get_examples_diploid(self, ts, seed=42): s_miss_mid[0, ts.num_sites // 2] = core.MISSING s_miss_all = s.copy() s_miss_all[0, :] = core.MISSING - genotypes.append(s_miss_last) - genotypes.append(s_miss_mid) - genotypes.append(s_miss_all) + # FIXME Handle MISSING properly. + #genotypes.append(s_miss_last) + #genotypes.append(s_miss_mid) + #genotypes.append(s_miss_all) m = ts.num_sites n = H.shape[1] G = np.zeros((m, n, n)) @@ -136,42 +148,43 @@ def get_emission_matrix_diploid(self, mu, m): e[:, core.MISSING_INDEX] = 1 return e - def get_examples_pars_diploid(self, ts, seed=42): + def get_examples_pars_diploid( + self, + ts, + mean_r=None, + mean_mu=None, + seed=42 + ): """Returns an iterator over combinations of examples and parameters.""" np.random.seed(seed) H, G, genotypes = self.get_examples_diploid(ts) m = ts.num_sites n = H.shape[1] - rs = [ - np.zeros(m) + 0.01, # Equal recombination and mutation - np.zeros(m) + 0.999, # Extreme - np.zeros(m) + 1e-6, # Extreme - np.random.rand(m), # Random - ] - mus = [ - np.zeros(m) + 0.01, # Equal recombination and mutation - np.zeros(m) + 0.33, # Extreme - np.zeros(m) + 1e-6, # Extreme - np.random.rand(m) * 0.33, # Random - ] + if mean_r is not None and mean_mu is not None: + rs = [ + mean_r * (np.random.rand(m) + 0.5) / 2 + ] + mus = [ + mean_mu * (np.random.rand(m) + 0.5) / 2 + ] + else: + rs = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.999, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m), # Random + ] + mus = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.33, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m) * 0.33, # Random + ] for s, r, mu in itertools.product(genotypes, rs, mus): r[0] = 0 e = self.get_emission_matrix_diploid(mu, m) yield n, m, G, s, e, r, mu - def get_examples_pars_larger_diploid(self, ts, mean_r=1e-5, mean_mu=1e-5, seed=42): - """Returns an iterator over combinations of examples and parameters.""" - np.random.seed(seed) - H, G, genotypes = self.get_examples_diploid(ts) - m = H.shape[0] - n = H.shape[1] - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - e = self.get_emission_matrix_diploid(mu, m) - for s in genotypes: - yield n, m, G, s, e, r, mu - # Prepare simple example datasets. def get_simple_n10_no_recombination(self, seed=42): ts = msprime.simulate( @@ -291,10 +304,10 @@ def get_multiallelic_n16(self, seed=42): return ts # Prepare a larger example dataset. - def get_large(self, n=50, length=1e5, mean_r=1e-5, mean_mu=1e-5, seed=42): + def get_larger(self, num_samples, seq_length, mean_r, mean_mu, seed=42): ts = msprime.simulate( - n + 1, - length=length, + num_samples + 1, + length=seq_length, mutation_rate=mean_mu, recombination_rate=mean_r, random_seed=seed, diff --git a/tests/test_LS_haploid_diploid.py b/tests/test_LS_haploid_diploid.py index 5496242..76de81f 100644 --- a/tests/test_LS_haploid_diploid.py +++ b/tests/test_LS_haploid_diploid.py @@ -1,418 +1,359 @@ import numpy as np import numba as nb -import msprime - -import lsbase +from . import lsbase import lshmm.fb_diploid as fbd import lshmm.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -class LSBase: - """Superclass of Li and Stephens tests.""" - - def haplotype_emission(self, mu, m): - e = np.zeros((m, 2)) - e[:, 0] = mu - e[:, 1] = 1 - mu - return e - - def example_parameters_haplotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - e = self.haplotype_emission(mu, m) - - for s in haplotypes: - yield n, m, H, s, e, r - - def example_parameters_genotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - - m = ts.get_num_sites() - n = H.shape[1] - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r - - # Test a bigger one. - def test_large(self, n=50, length=1e5, mean_r=1e-5, mean_mu=1e-5, seed=42): - ts = msprime.simulate( - n + 1, - length=length, - mutation_rate=mean_mu, - recombination_rate=mean_r, - random_seed=seed, +class TestNonTreeMethodsHaploid(lsbase.FBAlgorithmBase): + def verify(self, n, m, H_vs, s, e_vs, r): + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) + F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( + n, m, H_vs, s, e_vs, r, norm=True ) - self.verify_larger(ts) - - -class TestNonTreeMethodsHap(lsbase.FBAlgorithmBase): - def verify(self, ts): - for n, m, H_vs, s, e_vs, r in self.get_examples_pars_haploid(ts): - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True - ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) - - def verify_larger(self, ts): - # variants x samples - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes_larger(ts): - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True - ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) - - -class TestNonTreeMethodsDip(lsbase.FBAlgorithmBase): - def verify(self, ts): - for n, m, G_vs, s, e_vs, r in self.get_examples_pars_diploid(ts): - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - - F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True - ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) - - def verify_larger(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes_larger(ts): - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True - ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) - - -class TestNonTreeViterbiHap(lsbase.ViterbiAlgorithmBase): - def verify(self, ts): - for n, m, H_vs, s, e_vs, r in self.get_examples_pars_haploid(ts): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_check) + B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed=42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, mean_r=mean_r, mean_mu=mean_mu, seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) +class TestNonTreeMethodsDiploid(lsbase.FBAlgorithmBase): + def verify(self, n, m, G_vs, s, e_vs, r): + F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) + B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) + F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( + n, m, G_vs, s, e_vs, r, norm=False + ) + if ll_tmp != -np.inf: + B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) + F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - recombs, - ll_tmp, - ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, - V_argmaxes_tmp, - nb.typed.List(recombs), + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - def verify_larger(self, ts): - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes_larger(ts): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_check) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=False + ) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - recombs, - ll_tmp, - ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) ) - path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, V_argmaxes_tmp, nb.typed.List(recombs) - ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - -class TestNonTreeViterbiDip(lsbase.ViterbiAlgorithmBase): - def verify(self, ts): - for n, m, G_vs, s, e_vs, r in self.get_examples_pars_diploid(ts): - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) - phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, path_ll_vs) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=True + ) + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed=42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( + ts, mean_r=mean_r, mean_mu=mean_mu, seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - ( - V_tmp, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - recombs_single, - recombs_double, - ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip_no_pointer( - m, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - nb.typed.List(recombs_single), - nb.typed.List(recombs_double), - V_tmp, - ) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) +class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase): + def verify(self, n, m, H_vs, s, e_vs, r): + V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) + path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) + ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, ll_check) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) - def verify_larger(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes_larger(ts): - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) - phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, path_ll_vs) + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) - ( - V_tmp, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - recombs_single, - recombs_double, - ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip_no_pointer( - m, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - nb.typed.List(recombs_single), - nb.typed.List(recombs_double), - V_tmp, - ) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + ( + V_tmp, + V_argmaxes_tmp, + recombs, + ll_tmp, + ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap_no_pointer( + m, + V_argmaxes_tmp, + nb.typed.List(recombs), + ) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed=42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, mean_r=mean_r, mean_mu=mean_mu, seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) + + +class TestNonTreeViterbiDiploid(lsbase.ViterbiAlgorithmBase): + def verify(self, n, m, G_vs, s, e_vs, r): + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) + path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) + phased_path_vs = vd.get_phased_path(n, path_vs) + path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, path_ll_vs) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + ( + V_tmp, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + recombs_single, + recombs_double, + ll_tmp, + ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) + path_tmp = vd.backwards_viterbi_dip_no_pointer( + m, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + nb.typed.List(recombs_single), + nb.typed.List(recombs_double), + V_tmp, + ) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed=42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( + ts, mean_r=mean_r, mean_mu=mean_mu, seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r)