Skip to content

Commit

Permalink
Merge pull request #44 from szhan/fix_tests_diploid
Browse files Browse the repository at this point in the history
Fix get examples for diploid case
  • Loading branch information
astheeggeggs authored May 24, 2024
2 parents 8a8eba4 + 167764a commit 0c2b7a8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 42 deletions.
37 changes: 19 additions & 18 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,31 @@ def get_examples_haploid(self, ts, include_ancestors):
return ref_panel, queries

def get_examples_diploid(self, ts, include_ancestors):
# TODO Handle NONCOPY properly.
"""Both the reference panel and query contain unphased genotypes."""
if include_ancestors is True:
# TODO Handle NONCOPY properly.
raise NotImplementedError
ref_panel = ts.genotype_matrix()
num_sites = ref_panel.shape[0]
query1 = ref_panel[:, 0].reshape(1, num_sites) + ref_panel[:, 1].reshape(
1, num_sites
)
query2 = ref_panel[:, -1].reshape(1, num_sites) + ref_panel[:, -2].reshape(
1, num_sites
)
ref_panel = ref_panel[:, 2:]
# Create queries with MISSING
# Take some haplotypes as queries from the reference panel.
# Note that the queries contain unphased genotypes (calculated as allele dosages).
query1 = ref_panel[:, 0].reshape(1, num_sites)
query1 += ref_panel[:, 1].reshape(1, num_sites)
query2 = ref_panel[:, -2].reshape(1, num_sites)
query2 += ref_panel[:, -1].reshape(1, num_sites)
# Create queries with MISSING.
query_miss_last = query1.copy()
query_miss_last[0, -1] = core.MISSING
query_miss_mid = query1.copy()
query_miss_mid[0, ts.num_sites // 2] = core.MISSING
query_miss_all = query1.copy()
query_miss_all[0, :] = core.MISSING
queries = [query1, query2]
# FIXME Handle MISSING properly.
# genotypes.append(s_miss_last)
# genotypes.append(s_miss_mid)
# genotypes.append(s_miss_all)
ref_panel_size = ref_panel.shape[1]
G = np.zeros((num_sites, ref_panel_size, ref_panel_size))
query_miss_most = query1.copy()
query_miss_most[0, 1:] = core.MISSING
queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most]
# Exclude the arbitrarily chosen queries from the reference panel.
ref_panel = ref_panel[:, 2:-2]
num_ref_haps = ref_panel.shape[1] # Haplotypes, not individuals.
# Reference panel contains phased genotypes.
G = np.zeros((num_sites, num_ref_haps, num_ref_haps))
for i in range(num_sites):
G[i, :, :] = np.add.outer(ref_panel[i, :], ref_panel[i, :])
return ref_panel, G, queries
Expand Down
24 changes: 12 additions & 12 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,32 @@ def verify(
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 46
length = 1e5
Expand Down Expand Up @@ -217,32 +217,32 @@ def verify(
self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 46
length = 1e5
Expand Down
24 changes: 12 additions & 12 deletions tests/test_non_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,32 +126,32 @@ def verify(
self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 45
length = 1e5
Expand Down Expand Up @@ -354,32 +354,32 @@ def verify(
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n6(self, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8(self, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_simple_n16(self, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)

@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("include_ancestors", [False])
def test_ts_larger(self, include_ancestors):
ref_panel_size = 45
length = 1e5
Expand Down

0 comments on commit 0c2b7a8

Please sign in to comment.