# Reconstructed from the patch "Add tests for known breakpoints"
# (commit 22a2e862128acf51fc8cf640fbb38f88f134ba14, szhan, Sat, 6 Jul 2024),
# which adds this module as tests/test_nontree_vit_haploid_fixed_switches.py.
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase):
    """Cross-check the haploid Viterbi implementations in `lshmm.vit_haploid`.

    For every example produced by `get_examples_pars`, each
    `forwards_viterbi_hap_*` function is run, its path is traced back, and
    the `ll` value it returned is compared with (a) `path_ll_hap` evaluated
    on that path and (b) the `ll` value of the naive implementation.
    """

    def get_examples_pars(self, *args, **kwargs):
        """Return the example parameter sets consumed by `verify`.

        Currently a pure pass-through to the parent class.
        NOTE(review): assumes `lsbase.ViterbiAlgorithmBase` provides
        `get_examples_pars` accepting the arguments used in `verify` and
        yielding 7-tuples -- confirm against `lsbase`.
        """
        # TODO: Build examples with known breakpoints instead of delegating.
        # Planned cases:
        #   One fixed switch.
        #       Start
        #       Middle
        #       End
        #   Two fixed switches.
        #
        # Set the recombination rate to zero at all positions
        # except the sites of the fixed switches.
        # Check that the paths switch only at those sites.
        return super().get_examples_pars(*args, **kwargs)

    def _check_path_ll(self, path, ll, hmm_kwargs):
        """Assert that `ll` is close to `path_ll_hap` evaluated on `path`."""
        ll_check = vh.path_ll_hap(path=path, **hmm_kwargs)
        self.assertAllClose(ll, ll_check)

    def verify(self, ts, scale_mutation_rate, include_ancestors):
        """Run all Viterbi implementations on examples derived from `ts`.

        `scale_mutation_rate` and `include_ancestors` are forwarded
        unchanged to `get_examples_pars`; extreme rates are always included.
        """
        ploidy = 1
        emission_func = core.get_emission_probability_haploid
        # Implementations whose first return value is passed to
        # `backwards_viterbi_hap` as `V_last` without slicing.
        low_mem_forwards_funcs = (
            # Implementation: naive, low memory footprint
            vh.forwards_viterbi_hap_naive_low_mem,
            # Implementation: naive, low memory footprint, rescaling
            vh.forwards_viterbi_hap_naive_low_mem_rescaling,
            # Implementation: low memory footprint, rescaling
            vh.forwards_viterbi_hap_low_mem_rescaling,
            # Implementation: even lower memory footprint, rescaling
            vh.forwards_viterbi_hap_lower_mem_rescaling,
        )
        for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
            ts,
            ploidy=ploidy,
            scale_mutation_rate=scale_mutation_rate,
            include_ancestors=include_ancestors,
            include_extreme_rates=True,
        ):
            # Keyword arguments shared by every forwards function and by
            # `path_ll_hap`.
            hmm_kwargs = {
                "n": n,
                "m": m,
                "H": H_vs,
                "s": s,
                "e": e_vs,
                "r": r,
                "emission_func": emission_func,
            }

            # Implementation: naive (reference for all comparisons below)
            V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(**hmm_kwargs)
            path_vs = vh.backwards_viterbi_hap(
                m=m,
                V_last=V_vs[m - 1, :],
                P=P_vs,
            )
            self._check_path_ll(path_vs, ll_vs, hmm_kwargs)

            # Implementation: naive, vectorised
            V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(**hmm_kwargs)
            path_tmp = vh.backwards_viterbi_hap(
                m=m,
                V_last=V_tmp[m - 1, :],
                P=P_tmp,
            )
            self._check_path_ll(path_tmp, ll_tmp, hmm_kwargs)
            self.assertAllClose(ll_vs, ll_tmp)

            # Implementations: low memory footprint variants
            for forwards_func in low_mem_forwards_funcs:
                V_tmp, P_tmp, ll_tmp = forwards_func(**hmm_kwargs)
                path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
                self._check_path_ll(path_tmp, ll_tmp, hmm_kwargs)
                self.assertAllClose(ll_vs, ll_tmp)

            # Implementation: even lower memory footprint, rescaling,
            # no pointer
            (
                V_tmp,
                V_argmaxes_tmp,
                recombs,
                ll_tmp,
            ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(**hmm_kwargs)
            path_tmp = vh.backwards_viterbi_hap_no_pointer(
                m=m,
                V_argmaxes=V_argmaxes_tmp,
                recombs=nb.typed.List(recombs),
            )
            self._check_path_ll(path_tmp, ll_tmp, hmm_kwargs)
            self.assertAllClose(ll_vs, ll_tmp)

    @pytest.mark.parametrize("scale_mutation_rate", [True, False])
    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
        ts = self.get_ts_simple_n10_no_recomb()
        self.verify(ts, scale_mutation_rate, include_ancestors)

    @pytest.mark.parametrize("num_samples", [8, 16, 32])
    @pytest.mark.parametrize("scale_mutation_rate", [True, False])
    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
        ts = self.get_ts_simple(num_samples)
        self.verify(ts, scale_mutation_rate, include_ancestors)

    @pytest.mark.parametrize("scale_mutation_rate", [True, False])
    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
        ts = self.get_ts_simple_n8_high_recomb()
        self.verify(ts, scale_mutation_rate, include_ancestors)

    @pytest.mark.parametrize("scale_mutation_rate", [True, False])
    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_larger(self, scale_mutation_rate, include_ancestors):
        ts = self.get_ts_custom_pars(
            num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
        )
        self.verify(ts, scale_mutation_rate, include_ancestors)