Skip to content

Commit

Permalink
Refactor tests for haploid case with multiallelic sites
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 1, 2024
1 parent b095f74 commit c8b3963
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 91 deletions.
34 changes: 2 additions & 32 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ def get_ts_multiallelic_n10_no_recomb(self, seed=42):
assert ts.num_sites > 3
return ts

def get_ts_multiallelic_n6(self, seed=42):
def get_ts_multiallelic(self, num_samples, seed=42):
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=6,
samples=num_samples,
recombination_rate=1e-4,
sequence_length=40,
population_size=1e4,
Expand All @@ -299,40 +299,10 @@ def get_ts_multiallelic_n6(self, seed=42):
rate=1e-3,
random_seed=seed,
)
assert ts.num_sites > 5
return ts

def get_ts_multiallelic_n8(self, seed=42):
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=8,
recombination_rate=1e-4,
sequence_length=20,
population_size=1e4,
random_seed=seed,
),
rate=1e-4,
random_seed=seed,
)
assert ts.num_trees > 15
assert ts.num_sites > 5
return ts

def get_ts_multiallelic_n16(self, seed=42):
ts = msprime.sim_mutations(
msprime.sim_ancestry(
samples=16,
recombination_rate=1e-2,
sequence_length=20,
population_size=1e4,
random_seed=seed,
),
rate=1e-4,
random_seed=seed,
)
assert ts.num_sites > 5
return ts


class ForwardBackwardAlgorithmBase(LSBase):
"""Base for testing forwards-backwards algorithms."""
Expand Down
37 changes: 5 additions & 32 deletions tests/test_api_fb_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,11 @@ def test_ts_multiallelic_n10_no_recomb(
self, scale_mutation_rate, include_ancestors
):
ts = self.get_ts_multiallelic_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [6, 8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_multiallelic_n16(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)
31 changes: 4 additions & 27 deletions tests/test_api_vit_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,9 @@ def test_ts_multiallelic_n10_no_recomb(
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("num_samples", [6, 8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
def test_ts_multiallelic_n16(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

0 comments on commit c8b3963

Please sign in to comment.