Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for known breakpoints in Viterbi paths #145

Merged
merged 1 commit into from
Jul 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 294 additions & 0 deletions tests/test_nontree_vit_haploid_fixed_switches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploidFixedSwitches(lsbase.ViterbiAlgorithmBase):
def get_examples_pars(self, scale_mutation_rate):
# Set ref. panel and query.
# fmt: off
H = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0,],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
],
dtype=np.int8,
).T
query = np.array(
[
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0, ],
],
dtype=np.int8,
)
# fmt: on

n = H.shape[1] # Number of ref. haps
m = H.shape[0] # Number of sites

# Set constant mutation probability.
mu = np.zeros(m, dtype=np.float64) + 1e-4

# Set emission prob. matrix.
num_alleles = core.get_num_alleles(H, query)
e = core.get_emission_matrix_haploid(
mu=mu,
num_sites=m,
num_alleles=num_alleles,
scale_mutation_rate=scale_mutation_rate,
)

# Set recombination probabilities for different sets of fixed switches.
recomb_rates_no_switch = np.zeros(m, dtype=np.float64)
recomb_rates_nonzero_start = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_start[0] = 1e-2
recomb_rates_nonzero_mid = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_mid[2] = 1e-2
recomb_rates_nonzero_end = np.zeros_like(recomb_rates_no_switch)
recomb_rates_nonzero_end[-1] = 1e-2
recomb_rates_two_switches = np.zeros_like(recomb_rates_no_switch)
recomb_rates_two_switches[2] = 1e-2
recomb_rates_two_switches[8] = 1e-2
recomb_rates_arr = [
recomb_rates_no_switch,
recomb_rates_nonzero_start,
recomb_rates_nonzero_mid,
recomb_rates_nonzero_end,
recomb_rates_two_switches,
]

# Expected paths
# fmt: off
path_no_switch = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_start = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_mid = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1,],
dtype=np.int8,
)
path_nonzero_end = np.array(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0,],
dtype=np.int8,
)
path_two_switches = np.array(
[0, 0, 1, 1, 1, 1, 1, 1, 0, 0,],
dtype=np.int8,
)
# fmt: on
paths_arr = [
path_no_switch,
path_nonzero_start,
path_nonzero_mid,
path_nonzero_end,
path_two_switches,
]

for r, path in zip(recomb_rates_arr, paths_arr):
yield n, m, H, query, e, r, path

def verify(self, scale_mutation_rate):
for n, m, H_vs, s, e_vs, r, path in self.get_examples_pars(scale_mutation_rate):
emission_func = core.get_emission_probability_haploid

# Implementation: naive
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_vs = vh.backwards_viterbi_hap(
m=m,
V_last=V_vs[m - 1, :],
P=P_vs,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_vs, path)
self.assertAllClose(ll_vs, ll_check)

# Implementation: naive, vectorised
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(
m=m,
V_last=V_tmp[m - 1, :],
P=P_tmp,
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: naive, low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: low memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

# Implementation: even lower memory footprint, rescaling
V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

(
V_tmp,
V_argmaxes_tmp,
recombs,
ll_tmp,
) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m=m,
V_argmaxes=V_argmaxes_tmp,
recombs=nb.typed.List(recombs),
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H_vs,
path=path_tmp,
s=s,
e=e_vs,
r=r,
emission_func=emission_func,
)
np.testing.assert_equal(path_tmp, path)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
def test_fixed_switches(self, scale_mutation_rate):
self.verify(scale_mutation_rate)
Loading