diff --git a/tests/lsbase.py b/tests/lsbase.py index 48f4186..cd5711a 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -71,30 +71,31 @@ def get_examples_haploid(self, ts, include_ancestors): return ref_panel, queries def get_examples_diploid(self, ts, include_ancestors): - # TODO Handle NONCOPY properly. + """Both the reference panel and query contain unphased genotypes.""" + if include_ancestors is True: + # TODO Handle NONCOPY properly. + raise NotImplementedError ref_panel = ts.genotype_matrix() num_sites = ref_panel.shape[0] - query1 = ref_panel[:, 0].reshape(1, num_sites) + ref_panel[:, 1].reshape( - 1, num_sites - ) - query2 = ref_panel[:, -1].reshape(1, num_sites) + ref_panel[:, -2].reshape( - 1, num_sites - ) - ref_panel = ref_panel[:, 2:] - # Create queries with MISSING + # Take some haplotypes as queries from the reference panel. + # Note that the queries contain unphased genotypes (calculated as allele dosages). + query1 = ref_panel[:, 0].reshape(1, num_sites) + query1 += ref_panel[:, 1].reshape(1, num_sites) + query2 = ref_panel[:, -2].reshape(1, num_sites) + query2 += ref_panel[:, -1].reshape(1, num_sites) + # Create queries with MISSING. query_miss_last = query1.copy() query_miss_last[0, -1] = core.MISSING query_miss_mid = query1.copy() query_miss_mid[0, ts.num_sites // 2] = core.MISSING - query_miss_all = query1.copy() - query_miss_all[0, :] = core.MISSING - queries = [query1, query2] - # FIXME Handle MISSING properly. - # genotypes.append(s_miss_last) - # genotypes.append(s_miss_mid) - # genotypes.append(s_miss_all) - ref_panel_size = ref_panel.shape[1] - G = np.zeros((num_sites, ref_panel_size, ref_panel_size)) + query_miss_most = query1.copy() + query_miss_most[0, 1:] = core.MISSING + queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most] + # Exclude the arbitrarily chosen queries from the reference panel. + ref_panel = ref_panel[:, 2:-2] + num_ref_haps = ref_panel.shape[1] # Haplotypes, not individuals. + # Reference panel contains phased genotypes. + G = np.zeros((num_sites, num_ref_haps, num_ref_haps)) for i in range(num_sites): G[i, :, :] = np.add.outer(ref_panel[i, :], ref_panel[i, :]) return ref_panel, G, queries diff --git a/tests/test_API.py b/tests/test_API.py index 63b6f6e..e872a25 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -92,32 +92,32 @@ def verify( self.assertAllClose(B, B_vs) self.assertAllClose(ll_vs, ll) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n10_no_recomb(self, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n6(self, include_ancestors): ts = self.get_ts_simple_n6() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8(self, include_ancestors): ts = self.get_ts_simple_n8() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8_high_recomb(self, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n16(self, include_ancestors): ts = self.get_ts_simple_n16() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_larger(self, include_ancestors): ref_panel_size = 46 length = 1e5 @@ -217,32 +217,32 @@ def verify( self.assertAllClose(ll_vs, ll) self.assertAllClose(phased_path_vs, path) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n10_no_recomb(self, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n6(self, include_ancestors): ts = self.get_ts_simple_n6() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8(self, include_ancestors): ts = self.get_ts_simple_n8() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8_high_recomb(self, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n16(self, include_ancestors): ts = self.get_ts_simple_n16() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_larger(self, include_ancestors): ref_panel_size = 46 length = 1e5 diff --git a/tests/test_non_tree.py b/tests/test_non_tree.py index b1d9d5f..fc6baee 100644 --- a/tests/test_non_tree.py +++ b/tests/test_non_tree.py @@ -126,32 +126,32 @@ def verify( self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) self.assertAllClose(ll_vs, ll_tmp) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n10_no_recomb(self, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n6(self, include_ancestors): ts = self.get_ts_simple_n6() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8(self, include_ancestors): ts = self.get_ts_simple_n8() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8_high_recomb(self, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n16(self, include_ancestors): ts = self.get_ts_simple_n16() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_larger(self, include_ancestors): ref_panel_size = 45 length = 1e5 @@ -354,32 +354,32 @@ def verify( self.assertAllClose(ll_tmp, path_ll_tmp) self.assertAllClose(ll_vs, ll_tmp) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n10_no_recomb(self, include_ancestors): ts = self.get_ts_simple_n10_no_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n6(self, include_ancestors): ts = self.get_ts_simple_n6() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8(self, include_ancestors): ts = self.get_ts_simple_n8() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n8_high_recomb(self, include_ancestors): ts = self.get_ts_simple_n8_high_recomb() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_simple_n16(self, include_ancestors): ts = self.get_ts_simple_n16() self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors) - @pytest.mark.parametrize("include_ancestors", [True, False]) + @pytest.mark.parametrize("include_ancestors", [False]) def test_ts_larger(self, include_ancestors): ref_panel_size = 45 length = 1e5