Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 7, 2024
1 parent 741d203 commit 52ddab5
Showing 1 changed file with 84 additions and 61 deletions.
145 changes: 84 additions & 61 deletions tests/test_nontree_vit_haploid_fixed_switches.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,94 @@


class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase):
def get_examples_pars():
num_sites = 10
def get_examples_pars(self, scale_mutation_rate):
# Set ref. panel and query.
# fmt: off
H = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0,],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
],
dtype=np.int8,
).T
query = np.array(
[
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0, ],
],
dtype=np.int8,
)
# fmt: on

n = H.shape[1] # Number of ref. haps
m = H.shape[0] # Number of sites

# Set constant mutation probability.
mu = np.zeros(num_sites, dtype=np.float64) + 1e-4
mu = np.zeros(m, dtype=np.float64) + 1e-4

# Set emission prob. matrix.
num_alleles = core.get_num_alleles(H, query)
e = core.get_emission_matrix_haploid(
mu=mu,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)

# Set recombination probabilities for different sets of fixed switches.
recomb_rates_no_switch = np.zeros(num_sites, dtype=np.float64)
recomb_rates_one_switch_start = np.zeros_like(recomb_rates_no_switch)
recomb_rates_one_switch_start[0] = 1e-2
recomb_rates_one_switch_middle = np.zeros_like(recomb_rates_no_switch)
recomb_rates_one_switch_middle[num_sites / 2] = 1e-2
recomb_rates_one_switch_end = np.zeros_like(recomb_rates_no_switch)
recomb_rates_one_switch_end[-1] = 1e-2
recomb_rates_no_switch = np.zeros(m, dtype=np.float64)
recomb_rates_nonzero_start = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_start[0] = 1e-2
recomb_rates_nonzero_mid = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_mid[2] = 1e-2
recomb_rates_nonzero_end = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_end[-1] = 1e-2
recomb_rates_two_switches = np.zeros_like(recomb_rates_no_switch)
recomb_rates_two_switches[3] = 1e-2
recomb_rates_two_switches[6] = 1e-2
recomb_rates_two_switches[2] = 1e-2
recomb_rates_two_switches[8] = 1e-2
recomb_rates_arr = [
recomb_rates_no_switch,
recomb_rates_one_switch_start,
recomb_rates_one_switch_middle,
recomb_rates_one_switch_end,
recomb_rates_nonzero_start,
recomb_rates_nonzero_mid,
recomb_rates_nonzero_end,
recomb_rates_two_switches,
]

# Expected paths
path_no_switch = None
path_one_switch_start = None
path_one_switch_middle = None
path_one_switch_end = None
path_two_switches = None

for r in recomb_rates_arr:
yield mu, r
# fmt: off
path_no_switch = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_start = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_mid = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_end = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0,],
dtype=np.int8,
)
path_two_switches = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0,],
dtype=np.int8,
)
# fmt: on
paths_arr = [
path_no_switch,
path_nonzero_start,
path_nonzero_mid,
path_nonzero_end,
path_two_switches,
]

for r, path in zip(recomb_rates_arr, paths_arr):
yield n, m, H, query, e, r, path

def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
ts,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
def verify(self, scale_mutation_rate):
for n, m, H_vs, s, e_vs, r, path in self.get_examples_pars(scale_mutation_rate):
emission_func = core.get_emission_probability_haploid

# Implementation: naive
Expand Down Expand Up @@ -81,6 +124,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_vs, path)
self.assertAllClose(ll_vs, ll_check)

# Implementation: naive, vectorised
Expand Down Expand Up @@ -108,6 +152,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -132,6 +177,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -156,6 +202,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -180,6 +227,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -204,6 +252,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

Expand Down Expand Up @@ -236,36 +285,10 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)


@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)


@pytest.mark.parametrize("num_samples", [8, 16, 32])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)


@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)


@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)
def test_fixed_switches(self, scale_mutation_rate):
self.verify(scale_mutation_rate)

0 comments on commit 52ddab5

Please sign in to comment.