Skip to content

Commit

Permalink
Add tests for known breakpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 6, 2024
1 parent a6afc23 commit 22a2e86
Showing 1 changed file with 243 additions and 0 deletions.
243 changes: 243 additions & 0 deletions tests/test_nontree_vit_haploid_fixed_switches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase):
def get_examples_pars():
# One fixed switch.
# Start
# Middle
# End
# Two fixed switches.
#
# Set the recombination rate to zero at all positions
# except the sites of the fixed switches.
# Check the paths that switching only occur there.
pass

def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
ts,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
emission_func = core.get_emission_probability_haploid

# Implementation: naive
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_vs = vh.backwards_viterbi_hap(
m=m,
V_last=V_vs[m - 1, :],
P=P_vs,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_vs, ll_check)

# Implementation: naive, vectorised
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(
m=m,
V_last=V_tmp[m - 1, :],
P=P_tmp,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: even lower memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

(
V_tmp,
V_argmaxes_tmp,
recombs,
ll_tmp,
) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m=m,
V_argmaxes=V_argmaxes_tmp,
recombs=nb.typed.List(recombs),
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [8, 16, 32])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple(num_samples)
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
num_samples=45, seq_length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(ts, scale_mutation_rate, include_ancestors)

0 comments on commit 22a2e86

Please sign in to comment.