Skip to content

Commit

Permalink
Fix test examples for diploid case
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed May 23, 2024
1 parent 9450d8e commit 64756c0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
18 changes: 12 additions & 6 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,29 @@ def get_examples_haploid(self, ts, include_ancestors):
return ref_panel, queries

def get_examples_diploid(self, ts, include_ancestors):
if include_ancestors is not None:
"""Both the reference panel and query contain unphased genotypes."""
if include_ancestors is True:
# TODO Handle NONCOPY properly.
raise NotImplementedError
ref_panel = ts.genotype_matrix()
num_sites = ref_panel.shape[0]
# Queries are unphased genotypes (calculated as allele dosages).
query1 = ref_panel[:, 0].reshape(1, num_sites)
query2 = ref_panel[:, -1].reshape(1, num_sites)
ref_panel = ref_panel[:, 1:-1]
query1 += ref_panel[:, 1].reshape(1, num_sites)
query2 = ref_panel[:, -2].reshape(1, num_sites)
query2 += ref_panel[:, -1].reshape(1, num_sites)
# Create queries with MISSING.
query_miss_last = query1.copy()
query_miss_last[0, -1] = core.MISSING
query_miss_mid = query1.copy()
query_miss_mid[0, ts.num_sites // 2] = core.MISSING
query_miss_all = query1.copy()
query_miss_all[0, :] = core.MISSING
queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_all]
query_miss_most = query1.copy()
query_miss_most[0, 1:] = core.MISSING
queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most]
# Remove queries from the reference panel.
ref_panel = ref_panel[:, 2:-2]
ref_panel_size = ref_panel.shape[1]
# Reference panel contains unphased genotypes.
G = np.zeros((num_sites, ref_panel_size, ref_panel_size))
for i in range(num_sites):
G[i, :, :] = np.add.outer(ref_panel[i, :], ref_panel[i, :])
Expand Down
24 changes: 12 additions & 12 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,32 @@ def verify(
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 46
length = 1e5
Expand Down Expand Up @@ -217,32 +217,32 @@ def verify(
self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 46
length = 1e5
Expand Down
24 changes: 12 additions & 12 deletions tests/test_non_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,32 +126,32 @@ def verify(
self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 45
length = 1e5
Expand Down Expand Up @@ -354,32 +354,32 @@ def verify(
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 45
length = 1e5
Expand Down

0 comments on commit 64756c0

Please sign in to comment.