From 2ee2adf5eb9d7cdd02f5c13c79e4b60162686f44 Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 1 Jul 2024 12:26:27 +0100 Subject: [PATCH] Refactor tests for haploid case with multiallelic sites --- tests/lsbase.py | 34 ++----------------------- tests/test_api_fb_haploid_multi.py | 39 ++++++----------------------- tests/test_api_vit_haploid_multi.py | 33 +++++------------------- 3 files changed, 15 insertions(+), 91 deletions(-) diff --git a/tests/lsbase.py b/tests/lsbase.py index 402f08f..e0089b6 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -287,10 +287,10 @@ def get_ts_multiallelic_n10_no_recomb(self, seed=42): assert ts.num_sites > 3 return ts - def get_ts_multiallelic_n6(self, seed=42): + def get_ts_multiallelic(self, num_samples, seed=42): ts = msprime.sim_mutations( msprime.sim_ancestry( - samples=6, + samples=num_samples, recombination_rate=1e-4, sequence_length=40, population_size=1e4, @@ -299,40 +299,10 @@ def get_ts_multiallelic_n6(self, seed=42): rate=1e-3, random_seed=seed, ) - assert ts.num_sites > 5 - return ts - - def get_ts_multiallelic_n8(self, seed=42): - ts = msprime.sim_mutations( - msprime.sim_ancestry( - samples=8, - recombination_rate=1e-4, - sequence_length=20, - population_size=1e4, - random_seed=seed, - ), - rate=1e-4, - random_seed=seed, - ) assert ts.num_trees > 15 assert ts.num_sites > 5 return ts - def get_ts_multiallelic_n16(self, seed=42): - ts = msprime.sim_mutations( - msprime.sim_ancestry( - samples=16, - recombination_rate=1e-2, - sequence_length=20, - population_size=1e4, - random_seed=seed, - ), - rate=1e-4, - random_seed=seed, - ) - assert ts.num_sites > 5 - return ts - class ForwardBackwardAlgorithmBase(LSBase): """Base for testing forwards-backwards algorithms.""" diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index 86ad692..a90f57a 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -61,38 +61,13 @@ def test_ts_multiallelic_n10_no_recomb( self, scale_mutation_rate, include_ancestors ): ts = self.get_ts_multiallelic_n10_no_recomb() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n6() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) + self.verify(ts, scale_mutation_rate, include_ancestors) + @pytest.mark.parametrize("num_samples", [6, 8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n8() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n16() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) + def test_ts_multiallelic_n16( + self, num_samples, scale_mutation_rate, include_ancestors + ): + ts = self.get_ts_multiallelic(num_samples) + self.verify(ts, scale_mutation_rate, include_ancestors) diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py index e93e853..5020171 100644 --- a/tests/test_api_vit_haploid_multi.py +++ b/tests/test_api_vit_haploid_multi.py @@ -50,32 +50,11 @@ def test_ts_multiallelic_n10_no_recomb( include_ancestors=include_ancestors, ) + @pytest.mark.parametrize("num_samples", [6, 8, 16]) @pytest.mark.parametrize("scale_mutation_rate", [True, False]) @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n6() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n8() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) - - @pytest.mark.parametrize("scale_mutation_rate", [True, False]) - @pytest.mark.parametrize("include_ancestors", [True, False]) - def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors): - ts = self.get_ts_multiallelic_n16() - self.verify( - ts, - scale_mutation_rate=scale_mutation_rate, - include_ancestors=include_ancestors, - ) + def test_ts_multiallelic_n16( + self, num_samples, scale_mutation_rate, include_ancestors + ): + ts = self.get_ts_multiallelic(num_samples) + self.verify(ts, scale_mutation_rate, include_ancestors)