diff --git a/tests/test_nontree_vit_haploid_fixed_switches.py b/tests/test_nontree_vit_haploid_fixed_switches.py index 0b26a25..2753875 100644 --- a/tests/test_nontree_vit_haploid_fixed_switches.py +++ b/tests/test_nontree_vit_haploid_fixed_switches.py @@ -9,51 +9,94 @@ class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase): - def get_examples_pars(): - num_sites = 10 + def get_examples_pars(self, scale_mutation_rate): + # Set ref. panel and query. + # fmt: off + H = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0,], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + ], + dtype=np.int8, + ).T + query = np.array( + [ + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, ], + ], + dtype=np.int8, + ) + # fmt: on + + n = H.shape[1] # Number of ref. haps + m = H.shape[0] # Number of sites # Set constant mutation probability. - mu = np.zeros(num_sites, dtype=np.float64) + 1e-4 + mu = np.zeros(m, dtype=np.float64) + 1e-4 + + # Set emission prob. matrix. + num_alleles = core.get_num_alleles(H, query) + e = core.get_emission_matrix_haploid( + mu=mu, + num_sites=m, + num_alleles=num_alleles, + scale_mutation_rate=scale_mutation_rate, + ) # Set recombination probabilities for different sets of fixed switches. - recomb_rates_no_switch = np.zeros(num_sites, dtype=np.float64) - recomb_rates_one_switch_start = np.zeros_like(recomb_rates_no_switch) - recomb_rates_one_switch_start[0] = 1e-2 - recomb_rates_one_switch_middle = np.zeros_like(recomb_rates_no_switch) - recomb_rates_one_switch_middle[num_sites / 2] = 1e-2 - recomb_rates_one_switch_end = np.zeros_like(recomb_rates_no_switch) - recomb_rates_one_switch_end[-1] = 1e-2 + recomb_rates_no_switch = np.zeros(m, dtype=np.float64) + recomb_rates_nonzero_start = np.zeros_like(recomb_rates_no_switch) + recomb_rates_nonzero_start[0] = 1e-2 + recomb_rates_nonzero_mid = np.zeros_like(recomb_rates_no_switch) + recomb_rates_nonzero_mid[2] = 1e-2 + recomb_rates_nonzero_end = np.zeros_like(recomb_rates_no_switch) + recomb_rates_nonzero_end[-1] = 1e-2 recomb_rates_two_switches = np.zeros_like(recomb_rates_no_switch) - recomb_rates_two_switches[3] = 1e-2 - recomb_rates_two_switches[6] = 1e-2 + recomb_rates_two_switches[2] = 1e-2 + recomb_rates_two_switches[8] = 1e-2 recomb_rates_arr = [ recomb_rates_no_switch, - recomb_rates_one_switch_start, - recomb_rates_one_switch_middle, - recomb_rates_one_switch_end, + recomb_rates_nonzero_start, + recomb_rates_nonzero_mid, + recomb_rates_nonzero_end, recomb_rates_two_switches, ] # Expected paths - path_no_switch = None - path_one_switch_start = None - path_one_switch_middle = None - path_one_switch_end = None - path_two_switches = None - - for r in recomb_rates_arr: - yield mu, r + # fmt: off + path_no_switch = np.array( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + dtype=np.int8, + ) + path_nonzero_start = np.array( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + dtype=np.int8, + ) + path_nonzero_mid = np.array( + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1,], + dtype=np.int8, + ) + path_nonzero_end = np.array( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0,], + dtype=np.int8, + ) + path_two_switches = np.array( + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0,], + dtype=np.int8, + ) + # fmt: on + paths_arr = [ + path_no_switch, + path_nonzero_start, + path_nonzero_mid, + path_nonzero_end, + path_two_switches, + ] + for r, path in zip(recomb_rates_arr, paths_arr): + yield n, m, H, query, e, r, path - def verify(self, ts, scale_mutation_rate, include_ancestors): - ploidy = 1 - for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars( - ts, - ploidy=ploidy, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - include_extreme_rates=True, - ): + def verify(self, scale_mutation_rate): + for n, m, H_vs, s, e_vs, r, path in self.get_examples_pars(scale_mutation_rate): emission_func = core.get_emission_probability_haploid # Implementation: naive @@ -81,6 +124,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_vs, path) self.assertAllClose(ll_vs, ll_check) # Implementation: naive, vectorised @@ -108,6 +152,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -132,6 +177,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -156,6 +202,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -180,6 +227,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -204,6 +252,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) @@ -236,36 +285,10 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, emission_func=emission_func, ) + np.testing.assert_equal(path_tmp, path) self.assertAllClose(ll_tmp, ll_check) self.assertAllClose(ll_vs, ll_tmp) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_simple_n10_no_recomb() - self.verify(ts, scale_mutation_rate, include_ancestors) - - - @pytest.mark.parametrize("num_samples", [8, 16, 32]) - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors): - ts = self.get_ts_simple(num_samples) - self.verify(ts, scale_mutation_rate, include_ancestors) - - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_simple_n8_high_recomb() - self.verify(ts, scale_mutation_rate, include_ancestors) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_larger(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_custom_pars( - num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5 - ) - self.verify(ts, scale_mutation_rate, include_ancestors) + def test_fixed_switches(self, scale_mutation_rate): + self.verify(scale_mutation_rate)