From f3a99eff862736623a9d60357fae9e1e495893aa Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 19 Apr 2024 09:02:34 +0100 Subject: [PATCH] Refactor --- lshmm/api.py | 108 ++-- lshmm/core.py | 85 +++ lshmm/{forward_backward => }/fb_diploid.py | 250 +++------ lshmm/{forward_backward => }/fb_haploid.py | 49 +- lshmm/forward_backward/__init__.py | 0 lshmm/vit_diploid.py | 331 ++++------- lshmm/vit_haploid.py | 157 +++--- tests/lsbase.py | 304 +++++++++++ tests/test_API.py | 339 +++--------- tests/test_API_multiallelic.py | 222 ++------ tests/test_LS_haploid_diploid.py | 607 --------------------- tests/test_non_tree.py | 362 ++++++++++++ 12 files changed, 1248 insertions(+), 1566 deletions(-) create mode 100644 lshmm/core.py rename lshmm/{forward_backward => }/fb_diploid.py (61%) rename lshmm/{forward_backward => }/fb_haploid.py (53%) delete mode 100644 lshmm/forward_backward/__init__.py create mode 100644 tests/lsbase.py delete mode 100644 tests/test_LS_haploid_diploid.py create mode 100644 tests/test_non_tree.py diff --git a/lshmm/api.py b/lshmm/api.py index f917ecf..f6e7946 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -4,8 +4,15 @@ import numpy as np -from .forward_backward.fb_diploid import backward_ls_dip_loop, forward_ls_dip_loop -from .forward_backward.fb_haploid import backwards_ls_hap, forwards_ls_hap +from . import core +from .fb_diploid import ( + backward_ls_dip_loop, + forward_ls_dip_loop, +) +from .fb_haploid import ( + backwards_ls_hap, + forwards_ls_hap, +) from .vit_diploid import ( backwards_viterbi_dip, forwards_viterbi_dip_low_mem, @@ -18,15 +25,6 @@ path_ll_hap, ) -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 -MISSING_INDEX = 3 - -MISSING = -1 - def check_alleles(alleles, m): """ @@ -52,7 +50,7 @@ def check_alleles(alleles, m): if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) # Otherwise, process allele lists. - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(num_sites, dtype=np.int8) for i in range(num_sites): uniq_alleles = np.unique(alleles[i]) @@ -68,7 +66,7 @@ def checks( scale_mutation_based_on_n_alleles, ): """ - Checks that the input data and parameters are valid. + Check that the input data and parameters are valid. The reference panel must be a matrix of size (m, n) or (m, n, n). The query must be a matrix of size (k, m) or (k, m, 2). @@ -77,17 +75,21 @@ def checks( n = number of samples in the reference panel (haplotypes, not individuals). k = number of samples in the query (haplotypes, not individuals). + The mutation rate can be scaled according to the set of alleles + that can be mutated to based on the number of distinct alleles at each site. + :param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n). :param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2). :param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m. :param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m. - :param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False). + :param bool scale_mutation_based_on_n_alleles: Scale the mutation probability or not. :return: n, m, ploidy :rtype: tuple """ # Check reference panel if not len(reference_panel.shape) in (2, 3): - raise ValueError("Reference panel array must have 2 or 3 dimensions.") + err_msg = "Reference panel array must have 2 or 3 dimensions." + raise ValueError(err_msg) if len(reference_panel.shape) == 2: m, n = reference_panel.shape @@ -97,42 +99,49 @@ def checks( ploidy = 2 if ploidy == 2 and (reference_panel.shape[1] != reference_panel.shape[2]): - raise ValueError( - "Reference_panel dimensions are incorrect, perhaps a sample x sample x variant matrix was passed. Expected sites x samples x samples." + err_msg = ( + "Reference_panel dimensions are incorrect, " + "perhaps a sample x sample x variant matrix was passed. " + "Expected sites x samples x samples." ) + raise ValueError(err_msg) # Check query sequence(s) if query.shape[1] != m: - raise ValueError( - "Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed." + err_msg = ( + "Number of sites in query does not match reference panel. " + "If haploid, ensure a sites x samples matrix is passed." ) + raise ValueError(err_msg) - # Ensure that the mutation probability is either a scalar or vector of length m + # Ensure that the mutation probability is either a scalar or vector of length m. if isinstance(p_mutation, (int, float)): if not scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site." - ) + warn_msg = "Passed a scalar mutation probability, but not rescaling it." + warnings.warn(warn_msg) elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m: if scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site." - ) + warn_msg = "Passed a vector of mutation probabilities. Rescaling them." + warnings.warn(warn_msg) elif p_mutation is None: - warnings.warn( - "No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)" + warn_msg = ( + "No mutation probability passed. " + "Setting it based on Li & Stephens (2003) equations A2 and A3." ) + warnings.warn(warn_msg) else: - raise ValueError( - f"Mutation probability is not None, a scalar, or vector of length m: {m}" + err_msg = ( + f"Mutation probability is not None, a scalar, or vector of length {m}." ) + raise ValueError(err_msg) # Ensure that the recombination probability is either a scalar or a vector of length m if not ( isinstance(p_recombination, (int, float)) or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m) ): - raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}") + err_msg = f"Recombination probability is not a scalar or vector of length {m}." + raise ValueError(err_msg) return (n, m, ploidy) @@ -148,9 +157,10 @@ def set_emission_probabilities( scale_mutation_based_on_n_alleles, ): # Check alleles should go in here, and modify e before passing to the algorithm - # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. + # If alleles is not passed, we don't perform a test of alleles, + # but set n_alleles based on the reference_panel. if alleles is None: - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(m, dtype=np.int8) for j in range(reference_panel.shape[0]): uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) @@ -159,7 +169,7 @@ def set_emission_probabilities( n_alleles = check_alleles(alleles, m) if p_mutation is None: - # Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003). + # Set the mutation probability to be the proposed mutation probability in Li & Stephens (2003). theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)]) p_mutation = 0.5 * (theta_tilde / (n + theta_tilde)) @@ -172,15 +182,18 @@ def set_emission_probabilities( e = np.zeros((m, 2)) if scale_mutation_based_on_n_alleles: - # Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles. + # Scale mutation based on the number of alleles, + # so p_mutation is probability of mutation any given one of the alleles. # The overall mutation probability is then (n_alleles - 1) * p_mutation. e[:, 0] = p_mutation - p_mutation * np.equal( n_alleles, np.ones(m) ) # Added boolean in case we're at an invariant site e[:, 1] = 1 - (n_alleles - 1) * p_mutation else: - # No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything - # (summing over the states we can switch to). This means that we must rescale the probability of mutation to + # No scaling based on the number of alleles, + # so p_mutation is the probability of mutation to anything + # (summing over the states we can switch to). + # This means that we must rescale the probability of mutation to # a different allele by the number of alleles at the site. for j in range(m): if n_alleles[j] == 1: # In case we're at an invariant site @@ -194,12 +207,12 @@ def set_emission_probabilities( # Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector. # DEV: there's a wrinkle here. e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2 - e[:, UNEQUAL_BOTH_HOM] = p_mutation**2 - e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2 - e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation) - e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation) - e[:, MISSING_INDEX] = 1 + e[:, core.EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = p_mutation**2 + e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2 + e[:, core.REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation) + e[:, core.REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation) + e[:, core.MISSING_INDEX] = 1 return e @@ -233,8 +246,7 @@ def forwards( norm=True, ): """ - Run the Li and Stephens forwards algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens forwards algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, @@ -281,8 +293,7 @@ def backwards( scale_mutation_based_on_n_alleles=True, ): """ - Run the Li and Stephens backwards algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens backwards algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, @@ -330,8 +341,7 @@ def viterbi( scale_mutation_based_on_n_alleles=True, ): """ - Run the Li and Stephens Viterbi algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens Viterbi algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, diff --git a/lshmm/core.py b/lshmm/core.py new file mode 100644 index 0000000..30f86d7 --- /dev/null +++ b/lshmm/core.py @@ -0,0 +1,85 @@ +import numpy as np + +from lshmm import jit + + +EQUAL_BOTH_HOM = 4 +UNEQUAL_BOTH_HOM = 0 +BOTH_HET = 7 +REF_HOM_OBS_HET = 1 +REF_HET_OBS_HOM = 2 +MISSING_INDEX = 3 + +MISSING = -1 + + +""" Helper functions. """ + + +# https://github.com/numba/numba/issues/1269 +@jit.numba_njit +def np_apply_along_axis(func1d, axis, arr): + """Create Numpy-like functions for max, sum, etc.""" + assert arr.ndim == 2 + assert axis in [0, 1] + if axis == 0: + result = np.empty(arr.shape[1]) + for i in range(len(result)): + result[i] = func1d(arr[:, i]) + else: + result = np.empty(arr.shape[0]) + for i in range(len(result)): + result[i] = func1d(arr[i, :]) + return result + + +@jit.numba_njit +def np_amax(array, axis): + """Numba implementation of Numpy-vectorised max.""" + return np_apply_along_axis(np.amax, axis, array) + + +@jit.numba_njit +def np_sum(array, axis): + """Numba implementation of Numpy-vectorised sum.""" + return np_apply_along_axis(np.sum, axis, array) + + +@jit.numba_njit +def np_argmax(array, axis): + """Numba implementation of Numpy-vectorised argmax.""" + return np_apply_along_axis(np.argmax, axis, array) + + +""" Functions used across different implementations of the LS HMM. """ + + +@jit.numba_njit +def get_index_in_emission_matrix(ref_allele, query_allele): + is_allele_match = np.equal(ref_allele, query_allele) + is_query_missing = query_allele == MISSING + if is_allele_match or is_query_missing: + return 1 + return 0 + + +@jit.numba_njit +def get_index_in_emission_matrix_diploid(ref_allele, query_allele): + if query_allele == MISSING: + return MISSING_INDEX + else: + is_match = ref_allele == query_allele + is_ref_one = ref_allele == 1 + is_query_one = query_allele == 1 + return 4 * is_match + 2 * is_ref_one + is_query_one + + +@jit.numba_njit +def get_index_in_emission_matrix_diploid_G(ref_G, query_allele, n): + if query_allele == MISSING: + return MISSING_INDEX * np.ones((n, n), dtype=np.int64) + else: + is_match = ref_G == query_allele + is_ref_one = ref_G == 1 + is_query_one = query_allele == 1 + return 4 * is_match + 2 * is_ref_one + is_query_one diff --git a/lshmm/forward_backward/fb_diploid.py b/lshmm/fb_diploid.py similarity index 61% rename from lshmm/forward_backward/fb_diploid.py rename to lshmm/fb_diploid.py index 50ffe12..919febb 100644 --- a/lshmm/forward_backward/fb_diploid.py +++ b/lshmm/fb_diploid.py @@ -1,74 +1,27 @@ -"""Collection of functions to run forwards and backwards algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data, +where the data is structured as variants x samples x samples. +""" + import numpy as np +from lshmm import core from lshmm import jit -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - def forwards_ls_dip(n, m, G, s, e, r, norm=True): - """Matrix based diploid LS forward algorithm using numpy vectorisation.""" - # Initialise the forward tensor + """A matrix-based implementation using Numpy vectorisation.""" + # Initialise F = np.zeros((m, n, n)) F[0, :, :] = 1 / (n**2) c = np.ones(m) r_n = r / n - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones( - (n, n), dtype=np.int64 - ) # We could have chosen anything here, this just implies a multiplication by a constant. - else: - index = 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) + 2 * ( - G[0, :, :] == 1 - ).astype(np.int64) - if s[0, 0] == 1: - index += 1 - - F[0, :, :] *= e[0, index] + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[0, :, :], query_allele=s[0, 0], n=n + ) + F[0, :, :] *= e[0, emission_index] if norm: c[0] = np.sum(F[0, :, :]) @@ -76,15 +29,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( - G[l, :, :] == 1 - ).astype(np.int64) - - if s[0, l] == 1: - index += 1 + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], query_allele=s[0, l], n=n + ) # No change in both F[l, :, :] = (1 - r[l]) ** 2 * F[l - 1, :, :] @@ -93,11 +40,11 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) # Emission - F[l, :, :] *= e[l, index] + F[l, :, :] *= e[l, emission_index] c[l] = np.sum(F[l, :, :]) F[l, :, :] *= 1 / c[l] @@ -105,15 +52,9 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): else: # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( - G[l, :, :] == 1 - ).astype(np.int64) - - if s[0, l] == 1: - index += 1 + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], query_allele=s[0, l], n=n + ) # No change in both F[l, :, :] = (1 - r[l]) ** 2 * F[l - 1, :, :] @@ -122,12 +63,12 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 * np.sum(F[l - 1, :, :]) # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T # sum_j2 = np_sum(F[l - 1, :, :], 1).repeat(n).reshape((-1, n)) F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) # Emission - F[l, :, :] *= e[l, index] + F[l, :, :] *= e[l, emission_index] ll = np.log10(np.sum(F[l, :, :])) @@ -135,39 +76,36 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): def backwards_ls_dip(n, m, G, s, e, c, r): - """Matrix based diploid LS backward algorithm using numpy vectorisation.""" - # Initialise the backward tensor - B = np.zeros((m, n, n)) - + """A matrix-based implementation using Numpy vectorisation.""" # Initialise + B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n # Backwards for l in range(m - 2, -1, -1): - if s[0, l + 1] == MISSING: - index = MISSING_INDEX * np.ones( - (n, n), dtype=np.int64 - ) # We could have chosen anything here, this just implies a multiplication by a constant. - else: - index = ( - 4 * np.equal(G[l + 1, :, :], s[0, l + 1]).astype(np.int64) - + 2 * (G[l + 1, :, :] == 1).astype(np.int64) - + np.int64(s[0, l + 1] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l + 1, :, :], query_allele=s[0, l + 1], n=n + ) # No change in both B[l, :, :] = r_n[l + 1] ** 2 * np.sum( - e[l + 1, index.reshape((n, n))] * B[l + 1, :, :] + e[l + 1, emission_index.reshape((n, n))] * B[l + 1, :, :] ) # Both change B[l, :, :] += ( - (1 - r[l + 1]) ** 2 * B[l + 1, :, :] * e[l + 1, index.reshape((n, n))] + (1 - r[l + 1]) ** 2 + * B[l + 1, :, :] + * e[l + 1, emission_index.reshape((n, n))] ) # One changes - sum_j = np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n)) + sum_j = ( + core.np_sum(B[l + 1, :, :] * e[l + 1, emission_index], 0) + .repeat(n) + .reshape((-1, n)) + ) B[l, :, :] += ((1 - r[l + 1]) * r_n[l + 1]) * (sum_j + sum_j.T) B[l, :, :] *= 1 / c[l + 1] @@ -177,24 +115,19 @@ def backwards_ls_dip(n, m, G, s, e, c, r): @jit.numba_njit def forward_ls_dip_starting_point(n, m, G, s, e, r): """Naive implementation of LS diploid forwards algorithm.""" - # Initialise the forward tensor + # Initialise F = np.zeros((m, n, n)) r_n = r / n + for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - F[0, j1, j2] *= e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + F[0, j1, j2] *= e[0, emission_index] for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j1_change = np.zeros(n) F_j2_change = np.zeros(n) @@ -231,24 +164,24 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): # What is the emission? - if s[0, l] == MISSING: - F[l, j1, j2] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, j1, j2] *= e[l, core.MISSING_INDEX] else: if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -257,16 +190,13 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def backward_ls_dip_starting_point(n, m, G, s, e, r): - """Naive implementation of LS diploid backwards algorithm.""" - # Backwards - B = np.zeros((m, n, n)) - + """A naive implementation.""" # Initialise + B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n for l in range(m - 2, -1, -1): - # Determine the various components B_no_change = np.zeros((n, n)) B_j1_change = np.zeros(n) B_j2_change = np.zeros(n) @@ -274,8 +204,8 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -283,18 +213,18 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): @@ -335,21 +265,16 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): - """LS diploid forwards algoritm without vectorisation.""" - # Initialise the forward tensor + """LS diploid forwards algorithm without vectorisation.""" + # Initialise F = np.zeros((m, n, n)) for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - F[0, j1, j2] *= e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + F[0, j1, j2] *= e[0, emission_index] r_n = r / n c = np.ones(m) @@ -358,7 +283,6 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): F[0, :, :] *= 1 / c[0] for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j_change = np.zeros(n) @@ -375,8 +299,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -384,18 +308,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] c[l] = np.sum(F[l, :, :]) F[l, :, :] *= 1 / c[l] @@ -404,7 +328,6 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): else: for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j1_change = np.zeros(n) F_j2_change = np.zeros(n) @@ -425,8 +348,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -434,18 +357,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -455,13 +378,12 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): @jit.numba_njit def backward_ls_dip_loop(n, m, G, s, e, c, r): """LS diploid backwards algoritm without vectorisation.""" - # Initialise the backward tensor + # Initialise B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n for l in range(m - 2, -1, -1): - # Determine the various components B_no_change = np.zeros((n, n)) B_j1_change = np.zeros(n) B_j2_change = np.zeros(n) @@ -469,8 +391,8 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -479,18 +401,18 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): diff --git a/lshmm/forward_backward/fb_haploid.py b/lshmm/fb_haploid.py similarity index 53% rename from lshmm/forward_backward/fb_haploid.py rename to lshmm/fb_haploid.py index 69d01fc..2541e02 100644 --- a/lshmm/forward_backward/fb_haploid.py +++ b/lshmm/fb_haploid.py @@ -1,25 +1,27 @@ -"""Collection of functions to run forwards and backwards algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data, +where the data is structured as variants x samples. +""" import numpy as np +from lshmm import core from lshmm import jit -MISSING = -1 - @jit.numba_njit def forwards_ls_hap(n, m, H, s, e, r, norm=True): - """Matrix based haploid LS forward algorithm using numpy vectorisation.""" - # Initialise + """A matrix-based implementation using Numpy vectorisation.""" F = np.zeros((m, n)) r_n = r / n if norm: c = np.zeros(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_index = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + F[0, i] = 1 / n * e[0, emission_index] c[0] += F[0, i] for i in range(n): @@ -29,9 +31,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_index = core.get_index_in_emission_matrix( + ref_allele=H[l, i], query_allele=s[0, l] + ) + F[l, i] *= e[l, emission_index] c[l] += F[l, i] for i in range(n): @@ -43,17 +46,19 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): c = np.ones(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_index = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + F[0, i] = 1 / n * e[0, emission_index] # Forwards pass for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_index = core.get_index_in_emission_matrix( + ref_allele=H[l, i], query_allele=s[0, l] + ) + F[l, i] *= e[l, emission_index] ll = np.log10(np.sum(F[m - 1, :])) @@ -62,8 +67,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit def backwards_ls_hap(n, m, H, s, e, c, r): - """Matrix based haploid LS backward algorithm using numpy vectorisation.""" - # Initialise + """A matrix-based implementation using Numpy vectorisation.""" B = np.zeros((m, n)) for i in range(n): B[m - 1, i] = 1 @@ -74,15 +78,10 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - tmp_B[i] = ( - e[ - l + 1, - np.int64( - np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING - ), - ] - * B[l + 1, i] + emission_index = core.get_index_in_emission_matrix( + ref_allele=H[l + 1, i], query_allele=s[0, l + 1] ) + tmp_B[i] = e[l + 1, emission_index] * B[l + 1, i] tmp_B_sum += tmp_B[i] for i in range(n): B[l, i] = r_n[l + 1] * tmp_B_sum diff --git a/lshmm/forward_backward/__init__.py b/lshmm/forward_backward/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 316b5d0..f28567d 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -1,82 +1,39 @@ -"""Collection of functions to run Viterbi algorithms on dipoid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens Viterbi algorithm on diploid genotype data, +where the data is structured as variants x samples x samples. +""" import numpy as np +from . import core from . import jit -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - @jit.numba_njit def forwards_viterbi_dip_naive(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm.""" + """A naive implementation.""" # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V[0, j1, j2] = 1 / (n**2) * e[0, emission_index] for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) for j1 in range(n): for j2 in range(n): - # Get the vector to maximise over v = np.zeros((n, n)) for k1 in range(n): for k2 in range(n): @@ -89,7 +46,7 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2 else: v[k1, k2] *= r_n[l] ** 2 - V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V[l, :, :]) V[l, :, :] *= 1 / c[l] @@ -101,44 +58,37 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm, with reduced memory.""" + """A naive implementation with reduced memory.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + V_prev = np.zeros((n, n)) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] - # Take a look at Haploid Viterbi implementation in Jeromes code and see if we can pinch some ideas. + # Take a look at the haploid Viterbi implementation in Jerome's code, and + # see if we can pinch some ideas. # Diploid Viterbi, with smaller memory footprint. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) + for j1 in range(n): for j2 in range(n): - # Get the vector to maximise over v = np.zeros((n, n)) for k1 in range(n): for k2 in range(n): - v[k1, k2] = V_previous[k1, k2] + v[k1, k2] = V_prev[k1, k2] if (k1 == j1) and (k2 == j2): v[k1, k2] *= ( (1 - r[l]) ** 2 + 2 * (1 - r[l]) * r_n[l] + r_n[l] ** 2 @@ -147,10 +97,10 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): v[k1, k2] *= r_n[l] * (1 - r[l]) + r_n[l] ** 2 else: v[k1, k2] *= r_n[l] ** 2 - V[j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V) - V_previous = np.copy(V) / c[l] + V_prev = np.copy(V) / c[l] ll = np.sum(np.log10(c)) @@ -159,43 +109,35 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """An implementation with reduced memory.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + V_prev = np.zeros((n, n)) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) - c[l] = np.amax(V_previous) - argmax = np.argmax(V_previous) + c[l] = np.amax(V_prev) + argmax = np.argmax(V_prev) - V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) - arg_rowcol_max = np_argmax(V_previous, 0) + V_prev *= 1 / c[l] + V_rowcol_max = core.np_amax(V_prev, 0) + arg_rowcol_max = core.np_argmax(V_prev, 0) no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 single_switch = r_n[l] * (1 - r[l]) + r_n[l] ** 2 @@ -215,7 +157,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): else: template_single_switch = arg_rowcol_max[j2] * n + j2 - V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either + V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either P[l, j1, j2] = j1_j2 # Single or double switch? @@ -231,9 +173,9 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): V[j1, j2] = double_switch P[l, j1, j2] = argmax - V[j1, j2] *= e[l, index[j1, j2]] + V[j1, j2] *= e[l, emission_index[j1, j2]] j1_j2 += 1 - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) @@ -242,10 +184,10 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """An implementation with reduced memory and no pointer.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) + V_prev = np.zeros((n, n)) c = np.ones(m) r_n = r / n @@ -262,35 +204,27 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) - c[l] = np.amax(V_previous) - argmax = np.argmax(V_previous) + c[l] = np.amax(V_prev) + argmax = np.argmax(V_prev) V_argmaxes[l - 1] = argmax # added - V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) + V_prev *= 1 / c[l] + V_rowcol_max = core.np_amax(V_prev, 0) V_rowcol_maxes[l - 1, :] = V_rowcol_max - arg_rowcol_max = np_argmax(V_previous, 0) + arg_rowcol_max = core.np_argmax(V_prev, 0) V_rowcol_argmaxes[l - 1, :] = arg_rowcol_max no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 @@ -302,7 +236,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): V_single_switch = max(V_rowcol_max[j1], V_rowcol_max[j2]) - V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either + V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either # Single or double switch? single_switch_tmp = single_switch * V_single_switch @@ -317,13 +251,13 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): V[j1, j2] = double_switch recombs_double[l] = np.append(recombs_double[l], values=j1_j2) - V[j1, j2] *= e[l, index[j1, j2]] + V[j1, j2] *= e[l, emission_index[j1, j2]] j1_j2 += 1 - V_previous = np.copy(V) + V_prev = np.copy(V) - V_argmaxes[m - 1] = np.argmax(V_previous) - V_rowcol_maxes[m - 1, :] = np_amax(V_previous, 0) - V_rowcol_argmaxes[m - 1, :] = np_argmax(V_previous, 0) + V_argmaxes[m - 1] = np.argmax(V_prev) + V_rowcol_maxes[m - 1, :] = core.np_amax(V_prev, 0) + V_rowcol_argmaxes[m - 1, :] = core.np_argmax(V_prev, 0) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) return ( @@ -339,35 +273,27 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): - """Vectorised LS diploid Viterbi algorithm using numpy.""" + """An implementation using Numpy vectorisation.""" # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V[0, j1, j2] = 1 / (n**2) * e[0, emission_index] # Jumped the gun - vectorising. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) for j1 in range(n): for j2 in range(n): @@ -376,7 +302,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): v[j1, :] += r_n[l] * (1 - r[l]) v[:, j2] += r_n[l] * (1 - r[l]) v *= V[l - 1, :, :] - V[l, j1, j2] = np.amax(v) * e[l, index[j1, j2]] + V[l, j1, j2] = np.amax(v) * e[l, emission_index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V[l, :, :]) @@ -388,7 +314,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): - """Fully vectorised naive LS diploid Viterbi algorithm using numpy.""" + """Fully vectorised naive implementation using Numpy.""" char_both = np.eye(n * n).ravel().reshape((n, n, n, n)) char_col = np.tile(np.sum(np.eye(n * n).reshape((n, n, n, n)), 3), (n, 1, 1, 1)) char_row = np.copy(char_col).T @@ -396,28 +322,23 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) - + 2 * (G[0, :, :] == 1).astype(np.int64) - + np.int64(s[0, 0] == 1) - ) - V[0, :, :] = 1 / (n**2) * e[0, index] + + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[0, :, :], + query_allele=s[0, 0], + n=n, + ) + V[0, :, :] = 1 / (n**2) * e[0, emission_index] r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) - else: - index = ( - 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) - + 2 * (G[l, :, :] == 1).astype(np.int64) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid_G( + ref_G=G[l, :, :], + query_allele=s[0, l], + n=n, + ) v = ( (r_n[l] ** 2) + (1 - r[l]) ** 2 * char_both @@ -425,7 +346,7 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): ) v *= V[l - 1, :, :] P[l, :, :] = np.argmax(v.reshape(n, n, -1), 2) # Have to flatten to use argmax - V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, index] + V[l, :, :] = v.reshape(n, n, -1)[rows, cols, P[l, :, :]] * e[l, emission_index] c[l] = np.amax(V[l, :, :]) V[l, :, :] *= 1 / c[l] @@ -439,8 +360,9 @@ def backwards_viterbi_dip(m, V_last, P): """Run a backwards pass to determine the most likely path.""" assert V_last.ndim == 2 assert V_last.shape[0] == V_last.shape[1] - # Initialisation - path = np.zeros(m).astype(np.int64) + + # Initialise + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) # Backtrace @@ -455,8 +377,7 @@ def in_list(array, value): where = np.searchsorted(array, value) if where < array.shape[0]: return array[where] == value - else: - return False + return False @jit.numba_njit @@ -472,8 +393,9 @@ def backwards_viterbi_dip_no_pointer( """Run a backwards pass to determine the most likely path.""" assert V_last.ndim == 2 assert V_last.shape[0] == V_last.shape[1] - # Initialisation - path = np.zeros(m).astype(np.int64) + + # Initialise + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) n = V_last.shape[0] @@ -496,37 +418,24 @@ def backwards_viterbi_dip_no_pointer( def get_phased_path(n, path): - """Obtain the phased path.""" return np.unravel_index(path, (n, n)) @jit.numba_njit def path_ll_dip(n, m, G, phased_path, s, e, r): """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - if s[0, 0] == MISSING: - index = MISSING_INDEX - else: - index = ( - 4 * np.int64(np.equal(G[0, phased_path[0][0], phased_path[1][0]], s[0, 0])) - + 2 * np.int64(G[0, phased_path[0][0], phased_path[1][0]] == 1) - + np.int64(s[0, 0] == 1) - ) - log_prob_path = np.log10(1 / (n**2) * e[0, index]) + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[0, phased_path[0][0], phased_path[1][0]], query_allele=s[0, 0] + ) + log_prob_path = np.log10(1 / (n**2) * e[0, emission_index]) old_phase = np.array([phased_path[0][0], phased_path[1][0]]) r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX - else: - index = ( - 4 - * np.int64( - np.equal(G[l, phased_path[0][l], phased_path[1][l]], s[0, l]) - ) - + 2 * np.int64(G[l, phased_path[0][l], phased_path[1][l]] == 1) - + np.int64(s[0, l] == 1) - ) + emission_index = core.get_index_in_emission_matrix_diploid( + ref_allele=G[l, phased_path[0][l], phased_path[1][l]], + query_allele=s[0, l], + ) current_phase = np.array([phased_path[0][l], phased_path[1][l]]) phase_diff = np.sum(~np.equal(current_phase, old_phase)) @@ -540,7 +449,7 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): else: log_prob_path += np.log10(r_n[l] ** 2) - log_prob_path += np.log10(e[l, index]) + log_prob_path += np.log10(e[l, emission_index]) old_phase = current_phase return log_prob_path diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 7fec45e..b296516 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -1,57 +1,60 @@ -"""Collection of functions to run Viterbi algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens Viterbi algorithm on haploid genotype data, +where the data is structured as variants x samples. +""" import numpy as np +from . import core from . import jit -MISSING = -1 - @jit.numba_njit def viterbi_naive_init(n, m, H, s, e, r): - """Initialise naive implementation of LS viterbi.""" + """Initialise a naive implementation.""" V = np.zeros((m, n)) - P = np.zeros((m, n)).astype(np.int64) + P = np.zeros((m, n), dtype=np.int64) r_n = r / n + for i in range(n): - V[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + V[0, i] = 1 / n * e[0, emission_idx] return V, P, r_n @jit.numba_njit def viterbi_init(n, m, H, s, e, r): - """Initialise naive, but more space memory efficient implementation of LS viterbi.""" - V_previous = np.zeros(n) + """Initialise a naive, but more memory efficient, implementation.""" + V_prev = np.zeros(n) V = np.zeros(n) - P = np.zeros((m, n)).astype(np.int64) + P = np.zeros((m, n), dtype=np.int64) r_n = r / n for i in range(n): - V_previous[i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + V_prev[i] = 1 / n * e[0, emission_idx] - return V, V_previous, P, r_n + return V, V_prev, P, r_n @jit.numba_njit def forwards_viterbi_hap_naive(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm.""" - # Initialise + """A naive implementation of the forward pass.""" V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V[j - 1, k] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V[j - 1, k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -66,8 +69,7 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): - """Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy.""" - # Initialise + """A naive matrix-based implementation of the forward pass using Numpy.""" V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): @@ -75,7 +77,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): for i in range(n): v = np.copy(v_tmp) v[i] += V[j - 1, i] * (1 - r[j]) - v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] + ) + v *= e[j, emission_idx] P[j, i] = np.argmax(v) V[j, i] = v[P[j, i]] @@ -86,26 +91,24 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """A naive implementation of the forward pass with reduced memory.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) for j in range(1, m): for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V_prev[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: v[k] *= r_n[j] P[j, i] = np.argmax(v) V[i] = v[P[j, i]] - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.log10(np.amax(V)) @@ -114,30 +117,27 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """A naive implementation of the forward pass with reduced memory and rescaling.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) for j in range(1, m): - c[j] = np.amax(V_previous) - V_previous *= 1 / c[j] + c[j] = np.amax(V_prev) + V_prev *= 1 / c[j] for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V_prev[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: v[k] *= r_n[j] P[j, i] = np.argmax(v) V[i] = v[P[j, i]] - - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) @@ -146,24 +146,26 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """An implementation with reduced memory that exploits the Markov structure.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) for j in range(1, m): - argmax = np.argmax(V_previous) - c[j] = V_previous[argmax] - V_previous *= 1 / c[j] + argmax = np.argmax(V_prev) + c[j] = V_prev[argmax] + V_prev *= 1 / c[j] V = np.zeros(n) for i in range(n): - V[i] = V_previous[i] * (1 - r[j] + r_n[j]) + V[i] = V_prev[i] * (1 - r[j] + r_n[j]) P[j, i] = i if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - V_previous = np.copy(V) + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] + ) + V[i] *= e[j, emission_idx] + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -172,12 +174,14 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """An implementation with even smaller memory footprint that exploits the Markov structure.""" V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - P = np.zeros((m, n)).astype(np.int64) + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] + ) + V[i] = 1 / n * e[0, emission_idx] + P = np.zeros((m, n), dtype=np.int64) r_n = r / n c = np.ones(m) @@ -191,7 +195,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[j, i], query_allele=s[0, j] + ) + V[i] *= e[j, emission_idx] ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -200,16 +207,21 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """ + An implementation with even smaller memory footprint and rescaling + that exploits the Markov structure. + """ V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] + ) + V[i] = 1 / n * e[0, emission_idx] r_n = r / n c = np.ones(m) - recombs = [ - np.zeros(shape=0, dtype=np.int64) for _ in range(m) - ] # This is going to be filled with the templates we can recombine to that have higher prob than staying where we are. + # This is going to be filled with the templates we can recombine to + # that have higher prob than staying where we are. + recombs = [np.zeros(shape=0, dtype=np.int64) for _ in range(m)] V_argmaxes = np.zeros(m) @@ -225,7 +237,8 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): recombs[j] = np.append( recombs[j], i ) # We add template i as a potential template to recombine to at site j. - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_matrix(H[j, i], s[0, j]) + V[i] *= e[j, emission_idx] V_argmaxes[m - 1] = np.argmax(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -237,9 +250,8 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): @jit.numba_njit def backwards_viterbi_hap(m, V_last, P): """Run a backwards pass to determine the most likely path.""" - # Initialise assert len(V_last.shape) == 1 - path = np.zeros(m).astype(np.int64) + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) for j in range(m - 2, -1, -1): @@ -251,8 +263,7 @@ def backwards_viterbi_hap(m, V_last, P): @jit.numba_njit def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): """Run a backwards pass to determine the most likely path.""" - # Initialise - path = np.zeros(m).astype(np.int64) + path = np.zeros(m, dtype=np.int64) path[m - 1] = V_argmaxes[m - 1] for j in range(m - 2, -1, -1): @@ -266,14 +277,18 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): @jit.numba_njit def path_ll_hap(n, m, H, path, s, e, r): - """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - index = np.int64(np.equal(H[0, path[0]], s[0, 0]) or s[0, 0] == MISSING) - log_prob_path = np.log10((1 / n) * e[0, index]) + """Evaluate the log-likelihood of a path through a reference panel resulting in a sequence.""" + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[0, path[0]], query_allele=s[0, 0] + ) + log_prob_path = np.log10((1 / n) * e[0, emission_idx]) old = path[0] r_n = r / n for l in range(1, m): - index = np.int64(np.equal(H[l, path[l]], s[0, l]) or s[0, l] == MISSING) + emission_idx = core.get_index_in_emission_matrix( + ref_allele=H[l, path[l]], query_allele=s[0, l] + ) current = path[l] same = old == current @@ -282,7 +297,7 @@ def path_ll_hap(n, m, H, path, s, e, r): else: log_prob_path += np.log10(r_n[l]) - log_prob_path += np.log10(e[l, index]) + log_prob_path += np.log10(e[l, emission_idx]) old = current return log_prob_path diff --git a/tests/lsbase.py b/tests/lsbase.py new file mode 100644 index 0000000..0922f81 --- /dev/null +++ b/tests/lsbase.py @@ -0,0 +1,304 @@ +import itertools + +import numpy as np + +import msprime + +import lshmm.core as core + + +class LSBase: + """Base class of tests for Li & Stephens HMM algorithms.""" + + def verify(self, ts): + raise NotImplementedError() + + def assertAllClose(self, A, B): + np.testing.assert_allclose(A, B, rtol=1e-9, atol=0.0) + + # Helper routine + def get_num_alleles(self, ref_haps, query): + assert ref_haps.shape[0] == query.shape[1] + num_sites = ref_haps.shape[0] + num_alleles = np.zeros(num_sites, dtype=np.int8) + exclusion_set = np.array([core.MISSING]) + for i in range(num_sites): + uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) + num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) + assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." + return num_alleles + + # Haploid + def get_examples_haploid(self, ts): + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H = H[:, 1:] + haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] + s_miss_last = s.copy() + s_miss_last[0, -1] = core.MISSING + s_miss_mid = s.copy() + s_miss_mid[0, ts.num_sites // 2] = core.MISSING + s_miss_all = s.copy() + s_miss_all[0, :] = core.MISSING + haplotypes.append(s_miss_last) + haplotypes.append(s_miss_mid) + haplotypes.append(s_miss_all) + return H, haplotypes + + def get_emission_matrix_haploid( + self, mu, m, n_alleles, scale_mutation_based_on_n_alleles + ): + e = np.zeros((m, 2)) + if isinstance(mu, float): + mu = mu * np.ones(m) + if scale_mutation_based_on_n_alleles: + e[:, 0] = mu - mu * np.equal( + n_alleles, np.ones(m) + ) # Add boolean in case we're at an invariant site + e[:, 1] = 1 - (n_alleles - 1) * mu + else: + for j in range(m): + if n_alleles[j] == 1: + # In case we're at an invariant site + e[j, 0] = 0 + e[j, 1] = 1 + else: + e[j, 0] = mu[j] / (n_alleles[j] - 1) + e[j, 1] = 1 - mu[j] + return e + + def get_examples_pars_haploid( + self, ts, mean_r=None, mean_mu=None, scale_mutation=True, seed=42 + ): + """Returns an iterator over combinations of examples and parameters.""" + np.random.seed(seed) + H, haplotypes = self.get_examples_haploid(ts) + m = ts.num_sites + n = H.shape[1] + if mean_r is not None and mean_mu is not None: + rs = [mean_r * (np.random.rand(m) + 0.5) / 2] + mus = [mean_mu * (np.random.rand(m) + 0.5) / 2] + else: + rs = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.999, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m), # Random + ] + mus = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.2, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m) * 0.2, # Random + ] + for s, r, mu in itertools.product(haplotypes, rs, mus): + r[0] = 0 + # Must be calculated from the genotype matrix, + # because we can now get back mutations that + # result in the number of alleles being higher + # than the number of alleles in the reference panel. + n_alleles = self.get_num_alleles(H, s) + e = self.get_emission_matrix_haploid( + mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation + ) + yield n, m, H, s, e, r, mu + + # Diploid + def get_examples_diploid(self, ts): + H = ts.genotype_matrix() + s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) + H = H[:, 2:] + genotypes = [ + s, + H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), + ] + s_miss_last = s.copy() + s_miss_last[0, -1] = core.MISSING + s_miss_mid = s.copy() + s_miss_mid[0, ts.num_sites // 2] = core.MISSING + s_miss_all = s.copy() + s_miss_all[0, :] = core.MISSING + # FIXME Handle MISSING properly. + # genotypes.append(s_miss_last) + # genotypes.append(s_miss_mid) + # genotypes.append(s_miss_all) + m = ts.num_sites + n = H.shape[1] + G = np.zeros((m, n, n)) + for i in range(m): + G[i, :, :] = np.add.outer(H[i, :], H[i, :]) + return H, G, genotypes + + def get_emission_matrix_diploid(self, mu, m): + e = np.zeros((m, 8)) + e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = mu**2 + e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2 + e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) + e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) + e[:, core.MISSING_INDEX] = 1 + return e + + def get_examples_pars_diploid(self, ts, mean_r=None, mean_mu=None, seed=42): + """Returns an iterator over combinations of examples and parameters.""" + np.random.seed(seed) + H, G, genotypes = self.get_examples_diploid(ts) + m = ts.num_sites + n = H.shape[1] + if mean_r is not None and mean_mu is not None: + rs = [mean_r * (np.random.rand(m) + 0.5) / 2] + mus = [mean_mu * (np.random.rand(m) + 0.5) / 2] + else: + rs = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.999, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m), # Random + ] + mus = [ + np.zeros(m) + 0.01, # Equal recombination and mutation + np.zeros(m) + 0.33, # Extreme + np.zeros(m) + 1e-6, # Extreme + np.random.rand(m) * 0.33, # Random + ] + for s, r, mu in itertools.product(genotypes, rs, mus): + r[0] = 0 + e = self.get_emission_matrix_diploid(mu, m) + yield n, m, G, s, e, r, mu + + # Prepare simple example datasets. + def get_simple_n10_no_recombination(self, seed=42): + ts = msprime.simulate( + 10, + recombination_rate=0, + mutation_rate=0.5, + random_seed=seed, + ) + assert ts.num_sites > 3 + return ts + + def get_simple_n6(self, seed=42): + ts = msprime.simulate( + 6, + recombination_rate=2, + mutation_rate=7, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_simple_n8(self, seed=42): + ts = msprime.simulate( + 8, + recombination_rate=2, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_simple_n8_high_recombination(self, seed=42): + ts = msprime.simulate( + 8, + recombination_rate=20, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_trees > 15 + assert ts.num_sites > 5 + return ts + + def get_simple_n16(self, seed=42): + ts = msprime.simulate( + 16, + recombination_rate=2, + mutation_rate=5, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + # Prepare example datasets with multiallelic sites. + def get_multiallelic_n10_no_recombination(self, seed=42): + ts = msprime.sim_ancestry( + samples=10, + recombination_rate=0, + sequence_length=10, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-5, + random_seed=seed, + ) + assert ts.num_sites > 3 + return ts + + def get_multiallelic_n6(self, seed=42): + ts = msprime.sim_ancestry( + samples=6, + recombination_rate=1e-4, + sequence_length=40, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-3, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + def get_multiallelic_n8(self, seed=42): + ts = msprime.sim_ancestry( + samples=8, + recombination_rate=1e-4, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-4, + random_seed=seed, + ) + assert ts.num_sites > 5 + assert ts.num_trees > 15 + return ts + + def get_multiallelic_n16(self, seed=42): + ts = msprime.sim_ancestry( + samples=16, + recombination_rate=1e-2, + sequence_length=20, + population_size=1e4, + random_seed=seed, + ) + ts = msprime.sim_mutations( + ts, + rate=1e-4, + random_seed=seed, + ) + assert ts.num_sites > 5 + return ts + + # Prepare a larger example dataset. + def get_larger(self, num_samples, seq_length, mean_r, mean_mu, seed=42): + ts = msprime.simulate( + num_samples + 1, + length=seq_length, + mutation_rate=mean_mu, + recombination_rate=mean_r, + random_seed=seed, + ) + return ts + + +class ForwardBackwardAlgorithmBase(LSBase): + """Base for testing forwards-backwards algorithms.""" + + +class ViterbiAlgorithmBase(LSBase): + """Base for testing Viterbi algoritms.""" diff --git a/tests/test_API.py b/tests/test_API.py index 129e67a..fbc6f93 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -1,249 +1,14 @@ -# Simulation -import itertools - -# Python libraries -import msprime -import numpy as np -import pytest -import tskit - +from . import lsbase import lshmm as ls -import lshmm.forward_backward.fb_diploid as fbd -import lshmm.forward_backward.fb_haploid as fbh +import lshmm.fb_diploid as fbd +import lshmm.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - - -class LSBase: - """Superclass of Li and Stephens tests.""" - - def example_haplotypes(self, ts, seed=42): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) - H = H[:, 1:] - - haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] - s_tmp = s.copy() - s_tmp[0, -1] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = MISSING - haplotypes.append(s_tmp) - - return H, haplotypes - - def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles): - # Define the emission probability matrix - e = np.zeros((m, 2)) - if isinstance(mu, float): - mu = mu * np.ones(m) - - if scale_mutation_based_on_n_alleles: - e[:, 0] = mu - mu * np.equal( - n_alleles, np.ones(m) - ) # Added boolean in case we're at an invariant site - e[:, 1] = 1 - (n_alleles - 1) * mu - else: - for j in range(m): - if n_alleles[j] == 1: # In case we're at an invariant site - e[j, 0] = 0 - e[j, 1] = 1 - else: - e[j, 0] = mu[j] / (n_alleles[j] - 1) - e[j, 1] = 1 - mu[j] - return e - - def genotype_emission(self, mu, m): - # Define the emission probability matrix - e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = (1 - mu) ** 2 + mu**2 - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - - return e - - def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): - """Returns an iterator over combinations of haplotype, recombination and - mutation probabilities.""" - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - def _get_num_alleles(ref_haps, query): - assert ref_haps.shape[0] == query.shape[1] - num_sites = ref_haps.shape[0] - num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([MISSING]) - for i in range(num_sites): - uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) - num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) - assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." - return num_alleles - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - for s in haplotypes: - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - def example_genotypes(self, ts, seed=42): - np.random.seed(seed) - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) - H = H[:, 2:] - - genotypes = [ - s, - H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), - ] - - s_tmp = s.copy() - s_tmp[0, -1] = MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = MISSING - genotypes.append(s_tmp) - - m = ts.get_num_sites() - n = H.shape[1] - - G = np.zeros((m, n, n)) - for i in range(m): - G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - - return H, G, genotypes - - def example_parameters_genotypes(self, ts, seed=42): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - - e = self.genotype_emission(mu, m) - - for s, r, mu in itertools.product(genotypes, rs, mus): - r[0] = 0 - e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r, mu - - def example_parameters_genotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - - m = ts.get_num_sites() - n = H.shape[1] - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - # Error probability - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - # Define the emission probability matrix - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r, mu - - def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - assert np.allclose(A, B, rtol=1e-9, atol=0.0) - - # Define a bunch of very small tree-sequences for testing a collection of parameters on - def test_simple_n_10_no_recombination(self): - ts = msprime.simulate( - 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 - ) - assert ts.num_sites > 3 - self.verify(ts) - - def test_simple_n_6(self): - ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8(self): - ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) - assert ts.num_trees > 15 - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_16(self): - ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def verify(self, ts): - raise NotImplementedError() - - -class FBAlgorithmBase(LSBase): - """Base for forwards backwards algorithm tests.""" - - -class TestMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" +class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu) @@ -255,12 +20,30 @@ def verify(self, ts): F, c, ll = ls.forwards(H_vs, s, r, mu) B = ls.backwards(H_vs, s, c, r, mu) + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + self.verify(ts) + + def test_simple_n6(self): + ts = self.get_simple_n6() + self.verify(ts) + + def test_simple_n8(self): + ts = self.get_simple_n8() + self.verify(ts) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + self.verify(ts) + + def test_simple_n16(self): + ts = self.get_simple_n16() + self.verify(ts) -class TestMethodsDip(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" +class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts): - for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): + for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts): F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n, m, G_vs, s, e_vs, r, norm=True ) @@ -271,35 +54,85 @@ def verify(self, ts): self.assertAllClose(B, B_vs) self.assertAllClose(ll_vs, ll) + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + self.verify(ts) + + def test_simple_n6(self): + ts = self.get_simple_n6() + self.verify(ts) -class VitAlgorithmBase(LSBase): - """Base for viterbi algoritm tests.""" + def test_simple_n8(self): + ts = self.get_simple_n8() + self.verify(ts) + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + self.verify(ts) -class TestViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + def test_simple_n16(self): + ts = self.get_simple_n16() + self.verify(ts) + +class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts): V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n, m, H_vs, s, e_vs, r ) path_vs = vh.backwards_viterbi_hap(m, V_vs, P_vs) path, ll = ls.viterbi(H_vs, s, r, p_mutation=mu) - self.assertAllClose(ll_vs, ll) self.assertAllClose(path_vs, path) + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + self.verify(ts) + + def test_simple_n6(self): + ts = self.get_simple_n6() + self.verify(ts) + + def test_simple_n8(self): + ts = self.get_simple_n8() + self.verify(ts) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + self.verify(ts) + + def test_simple_n16(self): + ts = self.get_simple_n16() + self.verify(ts) -class TestViterbiDip(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" +class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): + for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts): V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs) phased_path_vs = vd.get_phased_path(n, path_vs) path, ll = ls.viterbi(G_vs, s, r, p_mutation=mu) - self.assertAllClose(ll_vs, ll) self.assertAllClose(phased_path_vs, path) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + self.verify(ts) + + def test_simple_n6(self): + ts = self.get_simple_n6() + self.verify(ts) + + def test_simple_n8(self): + ts = self.get_simple_n8() + self.verify(ts) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + self.verify(ts) + + def test_simple_n_16(self): + ts = self.get_simple_n16() + self.verify(ts) diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 92f1ab3..5a3abfa 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -1,196 +1,21 @@ -# Simulation -import itertools - -# Python libraries -import msprime -import numpy as np -import pytest -import tskit - +from . import lsbase import lshmm as ls -import lshmm.forward_backward.fb_diploid as fbd -import lshmm.forward_backward.fb_haploid as fbh -import lshmm.vit_diploid as vd +import lshmm.fb_haploid as fbh import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - - -class LSBase: - """Superclass of Li and Stephens tests.""" - - def example_haplotypes(self, ts, num_random=10, seed=42): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) - H = H[:, 1:] - - haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] - s_tmp = s.copy() - s_tmp[0, -1] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = MISSING - haplotypes.append(s_tmp) - - return H, haplotypes - - def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles): - # Define the emission probability matrix - e = np.zeros((m, 2)) - if isinstance(mu, float): - mu = mu * np.ones(m) - - if scale_mutation_based_on_n_alleles: - e[:, 0] = mu - mu * np.equal( - n_alleles, np.ones(m) - ) # Added boolean in case we're at an invariant site - e[:, 1] = 1 - (n_alleles - 1) * mu - else: - for j in range(m): - if n_alleles[j] == 1: # In case we're at an invariant site - e[j, 0] = 0 - e[j, 1] = 1 - else: - e[j, 0] = mu[j] / (n_alleles[j] - 1) - e[j, 1] = 1 - mu[j] - return e - - def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): - """Returns an iterator over combinations of haplotype, recombination and - mutation probabilities.""" - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - def _get_num_alleles(ref_haps, query): - assert ref_haps.shape[0] == query.shape[1] - num_sites = ref_haps.shape[0] - num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([MISSING]) - for i in range(num_sites): - uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) - num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) - assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." - return num_alleles - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - for s in haplotypes: - # Must be calculated from the genotype matrix because we can now get back mutations that - # result in the number of alleles being higher than the number of alleles in the reference panel. - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.2, np.zeros(m) + 1e-6, np.random.rand(m) * 0.2] - - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - n_alleles = _get_num_alleles(H, s) - e = self.haplotype_emission( - mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation - ) - yield n, m, H, s, e, r, mu - - def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - assert np.allclose(A, B, rtol=1e-9, atol=0.0) - - # Define a bunch of very small tree-sequences for testing a collection of parameters on - def test_simple_n_10_no_recombination(self): - ts = msprime.sim_ancestry( - samples=10, - recombination_rate=0, - random_seed=42, - sequence_length=10, - population_size=10000, - ) - ts = msprime.sim_mutations(ts, rate=1e-5, random_seed=42) - assert ts.num_sites > 3 - self.verify(ts) - - def test_simple_n_6(self): - ts = msprime.sim_ancestry( - samples=6, - recombination_rate=1e-4, - random_seed=42, - sequence_length=40, - population_size=10000, - ) - ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8(self): - ts = msprime.sim_ancestry( - samples=8, - recombination_rate=1e-4, - random_seed=42, - sequence_length=20, - population_size=10000, - ) - ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) - assert ts.num_sites > 5 - assert ts.num_trees > 15 - self.verify(ts) - - def test_simple_n_16(self): - ts = msprime.sim_ancestry( - samples=16, - recombination_rate=1e-2, - random_seed=42, - sequence_length=20, - population_size=10000, - ) - ts = msprime.sim_mutations(ts, rate=1e-4, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def verify(self, ts): - raise NotImplementedError() - - -class FBAlgorithmBase(LSBase): - """Base for forwards backwards algorithm tests.""" - - -class TestMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" +class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) F, c, ll = ls.forwards(H_vs, s, r, p_mutation=mu) B = ls.backwards(H_vs, s, c, r, p_mutation=mu) self.assertAllClose(F, F_vs) self.assertAllClose(B, B_vs) - # print(e_vs) self.assertAllClose(ll_vs, ll) - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes( + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid( ts, scale_mutation=False ): F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r) @@ -205,23 +30,48 @@ def verify(self, ts): self.assertAllClose(B, B_vs) self.assertAllClose(ll_vs, ll) + def test_multiallelic_n10_no_recombination(self): + ts = self.get_multiallelic_n10_no_recombination() + self.verify(ts) + + def test_multiallelic_n6(self): + ts = self.get_multiallelic_n6() + self.verify(ts) -class VitAlgorithmBase(LSBase): - """Base for viterbi algoritm tests.""" + def test_multiallelic_n8(self): + ts = self.get_multiallelic_n8() + self.verify(ts) + def test_multiallelic_n16(self): + ts = self.get_multiallelic_n16() + self.verify(ts) -class TestViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" +class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts): - for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): + for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts): V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling( n, m, H_vs, s, e_vs, r ) path_vs = vh.backwards_viterbi_hap(m, V_vs, P_vs) path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) path, ll = ls.viterbi(H_vs, s, r, p_mutation=mu) - self.assertAllClose(ll_vs, ll) self.assertAllClose(ll_vs, path_ll_hap) self.assertAllClose(path_vs, path) + + def test_multiallelic_n10_no_recombination(self): + ts = self.get_multiallelic_n10_no_recombination() + self.verify(ts) + + def test_multiallelic_n6(self): + ts = self.get_multiallelic_n6() + self.verify(ts) + + def test_multiallelic_n8(self): + ts = self.get_multiallelic_n8() + self.verify(ts) + + def test_multiallelic_n16(self): + ts = self.get_multiallelic_n16() + self.verify(ts) diff --git a/tests/test_LS_haploid_diploid.py b/tests/test_LS_haploid_diploid.py deleted file mode 100644 index 9b9f7d8..0000000 --- a/tests/test_LS_haploid_diploid.py +++ /dev/null @@ -1,607 +0,0 @@ -# Simulation -import itertools - -# Python libraries -import msprime -import numpy as np -import pytest - -import lshmm.forward_backward.fb_diploid as fbd -import lshmm.forward_backward.fb_haploid as fbh -import lshmm.vit_diploid as vd -import lshmm.vit_haploid as vh - -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - -import numba as nb -import tskit - - -class LSBase: - """Superclass of Li and Stephens tests.""" - - def example_haplotypes(self, ts): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) - H = H[:, 1:] - - haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] - s_tmp = s.copy() - s_tmp[0, -1] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING - haplotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = MISSING - haplotypes.append(s_tmp) - - return H, haplotypes - - def haplotype_emission(self, mu, m): - # Define the emission probability matrix - e = np.zeros((m, 2)) - e[:, 0] = mu # If they match - e[:, 1] = 1 - mu # If they don't match - - return e - - def genotype_emission(self, mu, m): - # Define the emission probability matrix - e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = 1 - mu - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - - return e - - def example_parameters_haplotypes(self, ts, seed=42): - """Returns an iterator over combinations of haplotype, recombination and mutation probabilities.""" - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - e = self.haplotype_emission(mu, m) - - for s in haplotypes: - yield n, m, H, s, e, r - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - - e = self.haplotype_emission(mu, m) - - for s, r, mu in itertools.product(haplotypes, rs, mus): - r[0] = 0 - e = self.haplotype_emission(mu, m) - yield n, m, H, s, e, r - - def example_parameters_haplotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, haplotypes = self.example_haplotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - # Error probability - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - # Define the emission probability matrix - e = self.haplotype_emission(mu, m) - - for s in haplotypes: - yield n, m, H, s, e, r - - def example_genotypes(self, ts, seed=42): - H = ts.genotype_matrix() - s = H[:, 0].reshape(1, H.shape[0]) + H[:, 1].reshape(1, H.shape[0]) - H = H[:, 2:] - - genotypes = [ - s, - H[:, -1].reshape(1, H.shape[0]) + H[:, -2].reshape(1, H.shape[0]), - ] - - s_tmp = s.copy() - s_tmp[0, -1] = MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING - genotypes.append(s_tmp) - s_tmp = s.copy() - s_tmp[0, :] = MISSING - genotypes.append(s_tmp) - - m = ts.get_num_sites() - n = H.shape[1] - - G = np.zeros((m, n, n)) - for i in range(m): - G[i, :, :] = np.add.outer(H[i, :], H[i, :]) - - return H, G, genotypes - - def example_parameters_genotypes(self, ts, seed=42): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - n = H.shape[1] - m = ts.get_num_sites() - - # Here we have equal mutation and recombination - r = np.zeros(m) + 0.01 - mu = np.zeros(m) + 0.01 - r[0] = 0 - - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r - - # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] - - for s, r, mu in itertools.product(genotypes, rs, mus): - r[0] = 0 - e = self.genotype_emission(mu, m) - yield n, m, G, s, e, r - - def example_parameters_genotypes_larger( - self, ts, seed=42, mean_r=1e-5, mean_mu=1e-5 - ): - np.random.seed(seed) - H, G, genotypes = self.example_genotypes(ts) - - m = ts.get_num_sites() - n = H.shape[1] - - r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - r[0] = 0 - - # Error probability - mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - - # Define the emission probability matrix - e = self.genotype_emission(mu, m) - - for s in genotypes: - yield n, m, G, s, e, r - - def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - # assert np.allclose(A, B, rtol=1e-9, atol=0.0) - assert np.allclose(A, B, rtol=1e-09, atol=1e-08) - - # Define a bunch of very small tree-sequences for testing a collection of parameters on - def test_simple_n_10_no_recombination(self): - ts = msprime.simulate( - 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 - ) - assert ts.num_sites > 3 - self.verify(ts) - - def test_simple_n_6(self): - ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8(self): - ts = msprime.simulate(8, recombination_rate=2, mutation_rate=5, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_8_high_recombination(self): - ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) - assert ts.num_trees > 15 - assert ts.num_sites > 5 - self.verify(ts) - - def test_simple_n_16(self): - ts = msprime.simulate(16, recombination_rate=2, mutation_rate=5, random_seed=42) - assert ts.num_sites > 5 - self.verify(ts) - - # Test a bigger one. - def test_large(self, n=50, length=100000, mean_r=1e-5, mean_mu=1e-5, seed=42): - ts = msprime.simulate( - n + 1, - length=length, - mutation_rate=mean_mu, - recombination_rate=mean_r, - random_seed=seed, - ) - self.verify_larger(ts) - - def verify(self, ts): - raise NotImplementedError() - - def verify_larger(self, ts): - pass - - -class FBAlgorithmBase(LSBase): - """Base for forwards backwards algorithm tests.""" - - -class TestNonTreeMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" - - def verify(self, ts): - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts): - e_sv = e_vs.T - H_sv = H_vs.T - - # variants x samples - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True - ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) - - def verify_larger(self, ts): - # variants x samples - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes_larger(ts): - e_sv = e_vs.T - H_sv = H_vs.T - - F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) - B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap( - n, m, H_vs, s, e_vs, r, norm=True - ) - B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) - - -class TestNonTreeMethodsDip(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" - - def verify(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts): - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True - ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) - - def verify_larger(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes_larger(ts): - F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) - B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) - self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) - F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=False - ) - if ll_tmp != -np.inf: - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose( - np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) - ) - F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( - n, m, G_vs, s, e_vs, r, norm=True - ) - B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) - self.assertAllClose(ll_vs, ll_tmp) - self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) - - -class VitAlgorithmBase(LSBase): - """Base for viterbi algoritm tests.""" - - -class TestNonTreeViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" - - def verify(self, ts): - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes(ts): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_check) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - recombs, - ll_tmp, - ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, - V_argmaxes_tmp, - nb.typed.List(recombs), - ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - def verify_larger(self, ts): - for n, m, H_vs, s, e_vs, r in self.example_parameters_haplotypes_larger(ts): - V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) - path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) - ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, ll_check) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - recombs, - ll_tmp, - ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( - n, m, H_vs, s, e_vs, r - ) - path_tmp = vh.backwards_viterbi_hap_no_pointer( - m, V_argmaxes_tmp, nb.typed.List(recombs) - ) - ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, ll_check) - self.assertAllClose(ll_vs, ll_tmp) - - -class TestNonTreeViterbiDip(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" - - def verify(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes(ts): - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) - phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, path_ll_vs) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - recombs_single, - recombs_double, - ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip_no_pointer( - m, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - nb.typed.List(recombs_single), - nb.typed.List(recombs_double), - V_tmp, - ) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - def verify_larger(self, ts): - for n, m, G_vs, s, e_vs, r in self.example_parameters_genotypes_larger(ts): - V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) - path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) - phased_path_vs = vd.get_phased_path(n, path_vs) - path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) - self.assertAllClose(ll_vs, path_ll_vs) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - ( - V_tmp, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - recombs_single, - recombs_double, - ll_tmp, - ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) - path_tmp = vd.backwards_viterbi_dip_no_pointer( - m, - V_argmaxes_tmp, - V_rowcol_maxes_tmp, - V_rowcol_argmaxes_tmp, - nb.typed.List(recombs_single), - nb.typed.List(recombs_double), - V_tmp, - ) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) - - V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec( - n, m, G_vs, s, e_vs, r - ) - path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) - phased_path_tmp = vd.get_phased_path(n, path_tmp) - path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) - self.assertAllClose(ll_tmp, path_ll_tmp) - self.assertAllClose(ll_vs, ll_tmp) diff --git a/tests/test_non_tree.py b/tests/test_non_tree.py new file mode 100644 index 0000000..fa4c975 --- /dev/null +++ b/tests/test_non_tree.py @@ -0,0 +1,362 @@ +import numpy as np +import numba as nb + +from . import lsbase +import lshmm.fb_diploid as fbd +import lshmm.fb_haploid as fbh +import lshmm.vit_diploid as vd +import lshmm.vit_haploid as vh + + +class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): + def verify(self, n, m, H_vs, s, e_vs, r): + F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False) + B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m)) + + F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=True) + B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m)) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed = 42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, + mean_r=mean_r, + mean_mu=mean_mu, + seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) + + +class TestNonTreeForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): + def verify(self, n, m, G_vs, s, e_vs, r): + F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True) + B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r) + self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m)) + + F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=False) + if ll_tmp != -np.inf: + B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + + F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=False + ) + if ll_tmp != -np.inf: + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose( + np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m) + ) + + F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop( + n, m, G_vs, s, e_vs, r, norm=True + ) + B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r) + self.assertAllClose(ll_vs, ll_tmp) + self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m)) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed = 42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( + ts, + mean_r=mean_r, + mean_mu=mean_mu, + seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) + + +class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase): + def verify(self, n, m, H_vs, s, e_vs, r): + V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r) + path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs) + ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, ll_check) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(n, m, H_vs, s, e_vs, r) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + ( + V_tmp, + V_argmaxes_tmp, + recombs, + ll_tmp, + ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer( + n, m, H_vs, s, e_vs, r + ) + path_tmp = vh.backwards_viterbi_hap_no_pointer( + m, + V_argmaxes_tmp, + nb.typed.List(recombs), + ) + ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, ll_check) + self.assertAllClose(ll_vs, ll_tmp) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed = 42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_haploid( + ts, + mean_r=mean_r, + mean_mu=mean_mu, + seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r) + + +class TestNonTreeViterbiDiploid(lsbase.ViterbiAlgorithmBase): + def verify(self, n, m, G_vs, s, e_vs, r): + V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_naive(n, m, G_vs, s, e_vs, r) + path_vs = vd.backwards_viterbi_dip(m, V_vs[m - 1, :, :], P_vs) + phased_path_vs = vd.get_phased_path(n, path_vs) + path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r) + self.assertAllClose(ll_vs, path_ll_vs) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem( + n, m, G_vs, s, e_vs, r + ) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + ( + V_tmp, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + recombs_single, + recombs_double, + ll_tmp, + ) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r) + path_tmp = vd.backwards_viterbi_dip_no_pointer( + m, + V_argmaxes_tmp, + V_rowcol_maxes_tmp, + V_rowcol_argmaxes_tmp, + nb.typed.List(recombs_single), + nb.typed.List(recombs_double), + V_tmp, + ) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec(n, m, G_vs, s, e_vs, r) + path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp) + phased_path_tmp = vd.get_phased_path(n, path_tmp) + path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r) + self.assertAllClose(ll_tmp, path_ll_tmp) + self.assertAllClose(ll_vs, ll_tmp) + + def test_simple_n10_no_recombination(self): + ts = self.get_simple_n10_no_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n6(self): + ts = self.get_simple_n6() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8(self): + ts = self.get_simple_n8() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n8_high_recombination(self): + ts = self.get_simple_n8_high_recombination() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_simple_n16(self): + ts = self.get_simple_n16() + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid(ts): + self.verify(n, m, H_vs, s, e_vs, r) + + def test_larger(self): + seed = 42 + num_samples = 50 + seq_length = 1e5 + mean_r = 1e-5 + mean_mu = 1e-5 + ts = self.get_larger( + num_samples, + seq_length, + mean_r, + mean_mu, + seed=seed, + ) + for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars_diploid( + ts, + mean_r=mean_r, + mean_mu=mean_mu, + seed=seed, + ): + self.verify(n, m, H_vs, s, e_vs, r)