Skip to content

Commit

Permalink
Merge pull request #62 from szhan/add_tests_scale_mutation_rate_haploid
Browse files Browse the repository at this point in the history
Add tests for scaling mutation rates in the haploid case
  • Loading branch information
astheeggeggs authored Jun 1, 2024
2 parents 6215da9 + d6ffa12 commit b8a99aa
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 48 deletions.
100 changes: 76 additions & 24 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,65 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, include_ancestors):
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

Expand Down Expand Up @@ -242,39 +268,65 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(ll_vs, ll)
self.assertAllClose(path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, include_ancestors):
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=46, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

Expand Down
100 changes: 76 additions & 24 deletions tests/test_non_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,59 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, include_ancestors):
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_larger(self, include_ancestors):
def test_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45,
length=1e5,
Expand All @@ -64,7 +90,7 @@ def test_larger(self, include_ancestors):
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

Expand Down Expand Up @@ -295,39 +321,65 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recombn(self, include_ancestors):
def test_ts_simple_n10_no_recombn(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, include_ancestors):
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, include_ancestors):
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, include_ancestors):
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, include_ancestors):
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(ts, scale_mutation_rate=True, include_ancestors=include_ancestors)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, include_ancestors):
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=True,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

Expand Down

0 comments on commit b8a99aa

Please sign in to comment.