Skip to content

Commit

Permalink
Add tests for known breakpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 7, 2024
1 parent a6afc23 commit 86a74c9
Showing 1 changed file with 294 additions and 0 deletions.
294 changes: 294 additions & 0 deletions tests/test_nontree_vit_haploid_fixed_switches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase):
def get_examples_pars(self, scale_mutation_rate):
# Set ref. panel and query.
# fmt: off
H = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0,],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
],
dtype=np.int8,
).T
query = np.array(
[
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0, ],
],
dtype=np.int8,
)
# fmt: on

n = H.shape[1] # Number of ref. haps
m = H.shape[0] # Number of sites

# Set constant mutation probability.
mu = np.zeros(m, dtype=np.float64) + 1e-4

# Set emission prob. matrix.
num_alleles = core.get_num_alleles(H, query)
e = core.get_emission_matrix_haploid(
mu=mu,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)

# Set recombination probabilities for different sets of fixed switches.
recomb_rates_no_switch = np.zeros(m, dtype=np.float64)
recomb_rates_nonzero_start = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_start[0] = 1e-2
recomb_rates_nonzero_mid = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_mid[2] = 1e-2
recomb_rates_nonzero_end = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_end[-1] = 1e-2
recomb_rates_two_switches = np.zeros_like(recomb_rates_no_switch)
recomb_rates_two_switches[2] = 1e-2
recomb_rates_two_switches[8] = 1e-2
recomb_rates_arr = [
recomb_rates_no_switch,
recomb_rates_nonzero_start,
recomb_rates_nonzero_mid,
recomb_rates_nonzero_end,
recomb_rates_two_switches,
]

# Expected paths
# fmt: off
path_no_switch = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_start = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_mid = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_end = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0,],
dtype=np.int8,
)
path_two_switches = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0,],
dtype=np.int8,
)
# fmt: on
paths_arr = [
path_no_switch,
path_nonzero_start,
path_nonzero_mid,
path_nonzero_end,
path_two_switches,
]

for r, path in zip(recomb_rates_arr, paths_arr):
yield n, m, H, query, e, r, path

def verify(self, scale_mutation_rate):
for n, m, H_vs, s, e_vs, r, path in self.get_examples_pars(scale_mutation_rate):
emission_func = core.get_emission_probability_haploid

# Implementation: naive
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_vs = vh.backwards_viterbi_hap(
m=m,
V_last=V_vs[m - 1, :],
P=P_vs,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_vs, path)
self.assertAllClose(ll_vs, ll_check)

# Implementation: naive, vectorised
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(
m=m,
V_last=V_tmp[m - 1, :],
P=P_tmp,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: even lower memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

(
V_tmp,
V_argmaxes_tmp,
recombs,
ll_tmp,
) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m=m,
V_argmaxes=V_argmaxes_tmp,
recombs=nb.typed.List(recombs),
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
def test_fixed_switches(self, scale_mutation_rate):
self.verify(scale_mutation_rate)

0 comments on commit 86a74c9

Please sign in to comment.