From 5ba84867c3da6f9694194131d752f6acbf5b33d0 Mon Sep 17 00:00:00 2001 From: Shing Zhan Date: Thu, 21 Mar 2024 16:40:27 +0000 Subject: [PATCH] Modify haploid Viterbi and forward-backward to handle NONCOPY state in reference panel --- lshmm/api.py | 39 ++-- lshmm/forward_backward/fb_haploid.py | 42 ++--- lshmm/vit_haploid.py | 70 ++++--- tests/test_API_noncopy.py | 256 ++++++++++++++++++++++++++ tests/test_API_noncopy_manual.py | 262 +++++++++++++++++++++++++++ 5 files changed, 609 insertions(+), 60 deletions(-) create mode 100644 tests/test_API_noncopy.py create mode 100644 tests/test_API_noncopy_manual.py diff --git a/lshmm/api.py b/lshmm/api.py index 5b88181..69ff87d 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -17,6 +17,9 @@ path_ll_hap, ) +MISSING = -1 +NONCOPY = -2 + EQUAL_BOTH_HOM = 4 UNEQUAL_BOTH_HOM = 0 BOTH_HET = 7 @@ -27,18 +30,27 @@ def check_alleles(alleles, m): """ - Checks the specified allele list and returns a list of lists - of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used - for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length - num_sites. + Checks the specified allele list and returns a list of allele lists of length m. + + If alleles is a 1D list of strings, assume that this list is used for each site + and return num_sites copies of this list. Otherwise, raise a ValueError + if alleles is not a list of length m. + + Note MISSING and NONCOPY values are excluded from the counts. + + :param list alleles: A list of lists of alleles or strings. + :param int m: Number of sites. + :return: An array of number of distinct alleles at each site. 
+ :rtype: numpy.ndarray """ if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) if len(alleles) != m: - raise ValueError("Malformed alleles list") - n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + raise ValueError("Number of alleles list is not equal to number of sites.") + exclusion_set = np.array([MISSING, NONCOPY]) + n_alleles = np.zeros(m, dtype=np.int8) + for i in range(m): + n_alleles[i] = np.sum(~np.isin(np.unique(alleles[i]), exclusion_set)) return n_alleles @@ -132,12 +144,11 @@ def set_emission_probabilities( # Check alleles should go in here, and modify e before passing to the algorithm # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(reference_panel[j, :], query[:, j]))) - for j in range(reference_panel.shape[0]) - ] - ) + exclusion_set = np.array([MISSING, NONCOPY]) + n_alleles = np.zeros(m, dtype=np.int8) + for j in range(reference_panel.shape[0]): + uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) + n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set)) else: n_alleles = check_alleles(alleles, m) diff --git a/lshmm/forward_backward/fb_haploid.py b/lshmm/forward_backward/fb_haploid.py index d0f03df..6dfefd6 100644 --- a/lshmm/forward_backward/fb_haploid.py +++ b/lshmm/forward_backward/fb_haploid.py @@ -4,6 +4,7 @@ from lshmm import jit MISSING = -1 +NONCOPY = -2 @jit.numba_njit @@ -17,9 +18,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): c = np.zeros(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + em_prob = 0 + if H[0, i] != NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + F[0, i] = 1 / n * em_prob c[0] += F[0, i] for i in range(n): @@ -29,9 +31,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): 
for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + em_prob = 0 + if H[l, i] != NONCOPY: + em_prob = e[l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)] + F[l, i] *= em_prob c[l] += F[l, i] for i in range(n): @@ -44,17 +47,19 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): c = np.ones(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + em_prob = 0 + if H[0, i] != NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + F[0, i] = 1 / n * em_prob # Forwards pass for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + em_prob = 0 + if H[l, i] != NONCOPY: + em_prob = e[l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)] + F[l, i] *= em_prob ll = np.log10(np.sum(F[m - 1, :])) @@ -75,15 +80,10 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - tmp_B[i] = ( - e[ - l + 1, - np.int64( - np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING - ), - ] - * B[l + 1, i] - ) + em_prob = 0 + if H[l + 1, i] != NONCOPY: + em_prob = e[l + 1, np.int64(np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING)] + tmp_B[i] = em_prob * B[l + 1, i] tmp_B_sum += tmp_B[i] for i in range(n): B[l, i] = r_n[l + 1] * tmp_B_sum diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 22aa7d3..9b2dd8f 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -4,6 +4,7 @@ from . 
import jit MISSING = -1 +NONCOPY = -2 @jit.numba_njit @@ -13,10 +14,10 @@ def viterbi_naive_init(n, m, H, s, e, r): P = np.zeros((m, n)).astype(np.int64) r_n = r / n for i in range(n): - V[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) - + em_prob = 0 + if H[0, i] != NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + V[0, i] = 1 / n * em_prob return V, P, r_n @@ -29,9 +30,10 @@ def viterbi_init(n, m, H, s, e, r): r_n = r / n for i in range(n): - V_previous[i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + em_prob = 0 + if H[0, i] != NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + V_previous[i] = 1 / n * em_prob return V, V_previous, P, r_n @@ -47,10 +49,10 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V[j - 1, k] - ) + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + v[k] = em_prob * V[j - 1, k] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -74,7 +76,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): for i in range(n): v = np.copy(v_tmp) v[i] += V[j - 1, i] * (1 - r[j]) - v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + v *= em_prob P[j, i] = np.argmax(v) V[j, i] = v[P[j, i]] @@ -94,10 +99,10 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] - ) + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + v[k] = 
(em_prob * V_previous[k]) if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -125,10 +130,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] - ) + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + v[k] = em_prob * V_previous[k] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -161,7 +166,10 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + V[i] *= em_prob V_previous = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -175,7 +183,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): # Initialise V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + em_prob = 0 + if H[0, i] != NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + V[i] = 1 / n * em_prob P = np.zeros((m, n)).astype(np.int64) r_n = r / n c = np.ones(m) @@ -190,7 +201,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + V[i] *= em_prob ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -203,7 +217,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): # Initialise V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + em_prob = 0 + if H[0, i] != 
NONCOPY: + em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + V[i] = 1 / n * em_prob r_n = r / n c = np.ones(m) recombs = [ @@ -224,7 +241,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): recombs[j] = np.append( recombs[j], i ) # We add template i as a potential template to recombine to at site j. - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + em_prob = 0 + if H[j, i] != NONCOPY: + em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + V[i] *= em_prob V_argmaxes[m - 1] = np.argmax(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) diff --git a/tests/test_API_noncopy.py b/tests/test_API_noncopy.py new file mode 100644 index 0000000..da651a8 --- /dev/null +++ b/tests/test_API_noncopy.py @@ -0,0 +1,256 @@ +import bisect +import itertools +import pytest + +import numpy as np + +import msprime +import tskit + +import lshmm as ls +import lshmm.forward_backward.fb_haploid as fbh +import lshmm.vit_haploid as vh + +MISSING = -1 +NONCOPY = -2 + + +# TODO: Either move tests in test_API_noncopy.py here or remove it altogether. + +class LSBase: + """Superclass of Li and Stephens tests.""" + + def get_ancestral_haplotypes(self, ts): + """ + Returns a numpy array of the haplotypes of the ancestors in the + specified tree sequence. + + Modified from + https://github.com/tskit-dev/tsinfer/blob/0c206d319f9c0dcb1ee205d5cc56576e3a88775e/tsinfer/eval_util.py#L244 + """ + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags[:] + flags[:] = 1 + nodes.set_columns(time=nodes.time, flags=flags) + + sites = tables.sites.position + tsp = tables.tree_sequence() + B = tsp.genotype_matrix().T + + # Modified. Originally, this was filled with NONCOPY by default. 
+ A = np.full((ts.num_nodes, ts.num_sites), NONCOPY, dtype=np.int8) + for edge in ts.edges(): + start = bisect.bisect_left(sites, edge.left) + end = bisect.bisect_right(sites, edge.right) + if sites[end - 1] == edge.right: + end -= 1 + A[edge.parent, start:end] = B[edge.parent, start:end] + A[: ts.num_samples] = B[: ts.num_samples] + + assert np.all(np.sum(A != NONCOPY, axis=0) > 0), \ + "Some sites have only NONCOPY states." + + return A.T + + def example_haplotypes(self, ts, num_random=10, seed=42): + H = self.get_ancestral_haplotypes(ts) + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + + haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] + s_tmp = s.copy() + s_tmp[0, -1] = MISSING # End + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, ts.num_sites // 2] = MISSING + haplotypes.append(s_tmp) + s_tmp = s.copy() + s_tmp[0, :] = MISSING + haplotypes.append(s_tmp) + + return H, haplotypes + + def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles): + # Define the emission probability matrix + e = np.zeros((m, 2)) + if isinstance(mu, float): + mu = mu * np.ones(m) + + if scale_mutation_based_on_n_alleles: + e[:, 0] = mu - mu * np.equal( + n_alleles, np.ones(m) + ) # Added boolean in case we're at an invariant site + e[:, 1] = 1 - (n_alleles - 1) * mu + else: + for j in range(m): + if n_alleles[j] == 1: # In case we're at an invariant site + e[j, 0] = 0 + e[j, 1] = 1 + else: + e[j, 0] = mu[j] / (n_alleles[j] - 1) + e[j, 1] = 1 - mu[j] + return e + + def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): + """Returns an iterator over combinations of haplotype, recombination and + mutation probabilities.""" + np.random.seed(seed) + H, haplotypes = self.example_haplotypes(ts) + n = H.shape[1] + m = ts.get_num_sites() + + # Here we have equal mutation and recombination + r = np.zeros(m) + 0.01 + mu = np.zeros(m) + 0.01 + r[0] = 0 + + def _get_n_states(H, s): + """ Get the number of states at each site. WIP. 
""" + assert H.shape[0] == s.shape[1] + m = H.shape[0] + n_states = np.zeros(m, dtype=np.int8) - 1 + exclude_set = np.array([MISSING, NONCOPY]) + for j in range(m): + proper_set = np.unique(np.append(H[j, :], s[:, j])) + n_states[j] = np.sum(~np.isin(proper_set, exclude_set)) + assert np.all(n_states >= 0) + return n_states + + for s in haplotypes: + # Must be calculated from the genotype matrix because we can now get back mutations that + # result in the number of alleles being higher than the number of alleles in the reference panel. + n_alleles = _get_n_states(H, s) + e = self.haplotype_emission( + mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation + ) + yield n, m, H, s, e, r, mu + + # Mixture of random and extremes + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] + mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2] + + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + n_alleles = _get_n_states(H, s) + e = self.haplotype_emission( + mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation + ) + yield n, m, H, s, e, r, mu + + def assertAllClose(self, A, B): + """Assert that all entries of two matrices are 'close'""" + assert np.allclose(A, B, rtol=1e-9, atol=0.0) + + # Define a bunch of very small tree-sequences for testing a collection of parameters on + def test_simple_n_10_no_recombination(self): + ts = msprime.sim_ancestry( + samples=10, + recombination_rate=0, + random_seed=42, + sequence_length=10, + population_size=10000, + model=msprime.SmcApproxCoalescent(), + ) + ts = msprime.sim_mutations(ts, rate=1e-5, random_seed=42) + assert ts.num_sites > 3 + self.verify(ts) + + def test_simple_n_6(self): + ts = msprime.sim_ancestry( + samples=6, + recombination_rate=1e-4, + random_seed=42, + sequence_length=40, + population_size=10000, + model=msprime.SmcApproxCoalescent(), + ) + ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=42) + assert ts.num_sites > 5 + self.verify(ts) 
+ + def test_simple_n_8(self): + ts = msprime.sim_ancestry( + samples=8, + recombination_rate=1e-4, + random_seed=42, + sequence_length=20, + population_size=10000, + model=msprime.SmcApproxCoalescent(), + ) + ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + assert ts.num_sites > 5 + assert ts.num_trees > 15 + self.verify(ts) + + def test_simple_n_16(self): + ts = msprime.sim_ancestry( + samples=16, + recombination_rate=1e-2, + random_seed=42, + sequence_length=20, + population_size=10000, + model=msprime.SmcApproxCoalescent(), + ) + ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) + assert ts.num_sites > 5 + self.verify(ts) + + def verify(self, ts): + raise NotImplementedError() + + +class FBAlgorithmBase(LSBase): + """Base for forwards backwards algorithm tests.""" + + +class TestMethodsHap(FBAlgorithmBase): + """Test that we compute the sample likelihoods across all implementations.""" + + def verify(self, ts): + for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu) + B = ls.backwards(H_vs, s, c, r, p_mutation=mu) + self.assertAllClose(F, F_vs) + self.assertAllClose(B, B_vs) + # print(e_vs) + self.assertAllClose(ll_vs, ll) + + for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes( + ts, scale_mutation=False + ): + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + F, c, ll = ls.forwards( + H_vs, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + ) + B = ls.backwards( + H_vs, s, c, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + ) + self.assertAllClose(F, F_vs) + self.assertAllClose(B, B_vs) + self.assertAllClose(ll_vs, ll) + + +class VitAlgorithmBase(LSBase): + """Base for Viterbi algorithm tests.""" + + +class 
TestViterbiHap(VitAlgorithmBase): + """Test that we have the same log-likelihood across all implementations""" + + def verify(self, ts): + for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + + V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_vs = vh.backwards_viterbi_hap(m, V_vs, P_vs) + path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + path, ll = ls.viterbi(H_vs, s, r, p_mutation=mu) + + self.assertAllClose(ll_vs, ll) + self.assertAllClose(ll_vs, path_ll_hap) + self.assertAllClose(path_vs, path) diff --git a/tests/test_API_noncopy_manual.py b/tests/test_API_noncopy_manual.py new file mode 100644 index 0000000..1f9653b --- /dev/null +++ b/tests/test_API_noncopy_manual.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest + +import lshmm.vit_haploid as vh + +MISSING = -1 +NONCOPY = -2 + + +# Helper functions +# TODO: Use the functions in the API instead. +def _get_emission_probabilities(m, p_mutation, n_alleles): + # Note that this is different than `set_emission_probabilities` in `api.py`. + # No scaling. + e = np.zeros((m, 2)) + for j in range(m): + if n_alleles[j] == 1: + e[j, 0] = 0 + e[j, 1] = 1 + else: + e[j, 0] = p_mutation[j] / (n_alleles[j] - 1) + e[j, 1] = 1 - p_mutation[j] + return e + + +def _get_num_alleles_per_site(H): + # Used to rescale mutation and recombination probabilities. + m = H.shape[0] # Number of sites + n_alleles = np.zeros(m, dtype=np.int64) - 1 + for i in range(m): + uniq_a = np.unique(H[i, :]) + assert len(uniq_a) > 0 + assert MISSING not in uniq_a + n_alleles[i] = np.sum(uniq_a != NONCOPY) + return n_alleles + + +# Prepare test data for testing. +def get_example_data(): + """ + Assumptions: + 1. Non-NONCOPY states are contiguous. + 2. No MISSING states in ref. panel. 
+ """ + NC = NONCOPY # Sugar + # Trivial case 1 + H_trivial_1 = np.array([ + [NC, NC], + [ 0, 1], + ]).T + query_trivial_1 = np.array([[0, 1]]) + path_trivial_1 = np.array([1, 1]) + # Trivial case 2 + H_trivial_2 = np.array([ + [NC, 1], + [ 0, 0], + ]).T + query_trivial_2 = np.array([[0, 1]]) + path_trivial_2 = np.array([1, 0]) + # Only NONCOPY + H_only_noncopy = np.array([ + [NC, NC, NC, NC, NC, NC, NC, NC, NC, NC], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]).T + query_only_noncopy = np.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) + path_only_noncopy = np.array([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + # NONCOPY on right + H_noncopy_on_right = np.array([ + [ 0, 0, 0, 0, 0, NC, NC, NC, NC, NC], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_noncopy_on_right = np.array([[ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]]) + path_noncopy_on_right = np.array([ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + # NONCOPY on left + H_noncopy_on_left = np.array([ + [NC, NC, NC, NC, NC, 0, 0, 0, 0, 0], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_noncopy_on_left = np.array([[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]) + path_noncopy_on_left = np.array([ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) + # NONCOPY in middle + H_noncopy_middle = np.array([ + [NC, NC, NC, 0, 0, 0, 0, NC, NC, NC], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_noncopy_middle = np.array([[ 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]]) + path_noncopy_middle = np.array([ 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]) + # Two switches + H_two_switches = np.array([ + [ 0, 0, 0, NC, NC, NC, NC, NC, NC, NC], + [NC, NC, NC, 0, 0, 0, NC, NC, NC, NC], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_two_switches = np.array([[ 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + path_two_switches = np.array([ 0, 0, 0, 1, 1, 1, 2, 2, 2, 2]) + # MISSING at switch position + # This causes more than one best paths + H_miss_switch = np.array([ + [NC, NC, NC, 0, 0, 0, 0, NC, NC, NC], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_miss_switch = np.array([[ 1, 1, 1, -1, 0, 0, 0, 1, 1, 1]]) + path_miss_switch = np.array([ 1, 
1, 1, 1, 0, 0, 0, 1, 1, 1]) + # MISSING left of switch position + H_miss_next_switch = np.array([ + [NC, NC, NC, 0, 0, 0, 0, NC, NC, NC], + [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ]).T + query_next_switch = np.array([[ 1, 1, -1, 0, 0, 0, 0, 1, 1, 1]]) + path_next_switch = np.array([ 1, 1, 1, 0, 0, 0, 0, 1, 1, 1]) + + return [ + (H_trivial_1, query_trivial_1, path_trivial_1), + (H_trivial_2, query_trivial_2, path_trivial_2), + (H_only_noncopy, query_only_noncopy, path_only_noncopy), + (H_noncopy_on_right, query_noncopy_on_right, path_noncopy_on_right), + (H_noncopy_on_left, query_noncopy_on_left, path_noncopy_on_left), + (H_noncopy_middle, query_noncopy_middle, path_noncopy_middle), + (H_two_switches, query_two_switches, path_two_switches), + (H_miss_switch, query_miss_switch, path_miss_switch), + (H_miss_next_switch, query_next_switch, path_next_switch), + ] + + +# Tests for naive matrix-based implementation. +@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_naive(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_naive(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll) + + +# Tests for naive matrix-based implementation using numpy. 
+@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_naive_vec(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_naive_vec(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll) + + +# Tests for naive matrix-based implementation with reduced memory. +@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_naive_low_mem(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll), f"{expected_ll} {actual_ll}" + + +# Tests for naive matrix-based implementation with reduced memory and rescaling. 
+@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_naive_low_mem_rescaling(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll) + + +# Tests for implementation with reduced memory and rescaling. +@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_low_mem_rescaling(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll) + + +# Tests for implementation with even more reduced memory and rescaling. 
+@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_lower_mem_rescaling(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, actual_ll = vh.forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll) + + +# Tests for implementation with even more reduced memory and rescaling, without keeping pointers. +@pytest.mark.parametrize( + "H, s, expected_path", get_example_data() +) +def test_forwards_viterbi_hap_lower_mem_rescaling_no_pointer(H, s, expected_path): + m, n = H.shape + assert m == s.shape[1] == len(expected_path) + + r = np.zeros(m, dtype=np.float64) + 0.20 + p_mutation = np.zeros(m, dtype=np.float64) + 0.10 + + n_alleles = _get_num_alleles_per_site(H) + e = _get_emission_probabilities(m, p_mutation, n_alleles) + + _, _, _, actual_ll = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r) + expected_ll = vh.path_ll_hap(n, m, H, expected_path, s, e, r) + + assert np.allclose(expected_ll, actual_ll)