From ee017b486d653be011647d2353c75611861699a7 Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 19 Apr 2024 09:02:34 +0100 Subject: [PATCH] Refactor --- lshmm/api.py | 98 +++++------ lshmm/core.py | 74 +++++++++ lshmm/forward_backward/fb_diploid.py | 181 ++++++++------------- lshmm/forward_backward/fb_haploid.py | 54 +++--- lshmm/vit_diploid.py | 235 ++++++++++----------------- lshmm/vit_haploid.py | 148 +++++++++-------- tests/test_API.py | 62 +++---- tests/test_API_multiallelic.py | 26 +-- tests/test_LS_haploid_diploid.py | 48 ++---- 9 files changed, 439 insertions(+), 487 deletions(-) create mode 100644 lshmm/core.py diff --git a/lshmm/api.py b/lshmm/api.py index f917ecf..9fe1692 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -4,6 +4,7 @@ import numpy as np +from . import core from .forward_backward.fb_diploid import backward_ls_dip_loop, forward_ls_dip_loop from .forward_backward.fb_haploid import backwards_ls_hap, forwards_ls_hap from .vit_diploid import ( @@ -18,15 +19,6 @@ path_ll_hap, ) -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 -MISSING_INDEX = 3 - -MISSING = -1 - def check_alleles(alleles, m): """ @@ -52,7 +44,7 @@ def check_alleles(alleles, m): if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) # Otherwise, process allele lists. - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(num_sites, dtype=np.int8) for i in range(num_sites): uniq_alleles = np.unique(alleles[i]) @@ -68,7 +60,7 @@ def checks( scale_mutation_based_on_n_alleles, ): """ - Checks that the input data and parameters are valid. + Check that the input data and parameters are valid. The reference panel must be a matrix of size (m, n) or (m, n, n). The query must be a matrix of size (k, m) or (k, m, 2). @@ -77,17 +69,21 @@ def checks( n = number of samples in the reference panel (haplotypes, not individuals). k = number of samples in the query (haplotypes, not individuals). + The mutation rate can be scaled according to the set of alleles + that can be mutated to based on the number of distinct alleles at each site. + :param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n). :param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2). :param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m. :param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m. - :param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False). + :param bool scale_mutation_based_on_n_alleles: Scale the mutation probability or not. :return: n, m, ploidy :rtype: tuple """ # Check reference panel if not len(reference_panel.shape) in (2, 3): - raise ValueError("Reference panel array must have 2 or 3 dimensions.") + err_msg = "Reference panel array must have 2 or 3 dimensions." + raise ValueError(err_msg) if len(reference_panel.shape) == 2: m, n = reference_panel.shape @@ -97,42 +93,49 @@ def checks( ploidy = 2 if ploidy == 2 and (reference_panel.shape[1] != reference_panel.shape[2]): - raise ValueError( - "Reference_panel dimensions are incorrect, perhaps a sample x sample x variant matrix was passed. Expected sites x samples x samples." + err_msg = ( + "Reference_panel dimensions are incorrect, " + "perhaps a sample x sample x variant matrix was passed. " + "Expected sites x samples x samples." 
) + raise ValueError(err_msg) # Check query sequence(s) if query.shape[1] != m: - raise ValueError( - "Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed." + err_msg = ( + "Number of sites in query does not match reference panel. " + "If haploid, ensure a sites x samples matrix is passed." ) + raise ValueError(err_msg) - # Ensure that the mutation probability is either a scalar or vector of length m + # Ensure that the mutation probability is either a scalar or vector of length m. if isinstance(p_mutation, (int, float)): if not scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site." - ) + warn_msg = "Passed a scalar mutation probability, but not rescaling it." + warnings.warn(warn_msg) elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m: if scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site." - ) + warn_msg = "Passed a vector of mutation probabilities. Rescaling them." + warnings.warn(warn_msg) elif p_mutation is None: - warnings.warn( - "No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)" + warn_msg = ( + "No mutation probability passed. " + "Setting it based on Li & Stephens (2003) equations A2 and A3." ) + warnings.warn(warn_msg) else: - raise ValueError( - f"Mutation probability is not None, a scalar, or vector of length m: {m}" + err_msg = ( + f"Mutation probability is not None, a scalar, or vector of length {m}." ) + raise ValueError(err_msg) # Ensure that the recombination probability is either a scalar or a vector of length m if not ( isinstance(p_recombination, (int, float)) or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m) ): - raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}") + err_msg = f"Recombination probability is not a scalar or vector of length {m}." + raise ValueError(err_msg) return (n, m, ploidy) @@ -148,9 +151,10 @@ def set_emission_probabilities( scale_mutation_based_on_n_alleles, ): # Check alleles should go in here, and modify e before passing to the algorithm - # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. + # If alleles is not passed, we don't perform a test of alleles, + # but set n_alleles based on the reference_panel. if alleles is None: - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(m, dtype=np.int8) for j in range(reference_panel.shape[0]): uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) @@ -159,7 +163,7 @@ def set_emission_probabilities( n_alleles = check_alleles(alleles, m) if p_mutation is None: - # Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003). + # Set the mutation probability to be the proposed mutation probability in Li & Stephens (2003). theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)]) p_mutation = 0.5 * (theta_tilde / (n + theta_tilde)) @@ -172,15 +176,18 @@ def set_emission_probabilities( e = np.zeros((m, 2)) if scale_mutation_based_on_n_alleles: - # Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles. 
+        # Scale mutation based on the number of alleles,
+        # so p_mutation is the probability of mutating to any given one of the alleles.
         # The overall mutation probability is then (n_alleles - 1) * p_mutation.
         e[:, 0] = p_mutation - p_mutation * np.equal(
             n_alleles, np.ones(m)
         )  # Added boolean in case we're at an invariant site
         e[:, 1] = 1 - (n_alleles - 1) * p_mutation
     else:
-        # No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything
-        # (summing over the states we can switch to). This means that we must rescale the probability of mutation to
+        # No scaling based on the number of alleles,
+        # so p_mutation is the probability of mutation to anything
+        # (summing over the states we can switch to).
+        # This means that we must rescale the probability of mutation to
         # a different allele by the number of alleles at the site.
         for j in range(m):
             if n_alleles[j] == 1:  # In case we're at an invariant site
@@ -194,12 +201,12 @@
         # Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector.
         # DEV: there's a wrinkle here.
         e = np.zeros((m, 8))
-        e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
-        e[:, UNEQUAL_BOTH_HOM] = p_mutation**2
-        e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
-        e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
-        e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
-        e[:, MISSING_INDEX] = 1
+        e[:, core.EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2
+        e[:, core.UNEQUAL_BOTH_HOM] = p_mutation**2
+        e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2
+        e[:, core.REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation)
+        e[:, core.REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation)
+        e[:, core.MISSING_INDEX] = 1

     return e
@@ -233,8 +240,7 @@ def forwards(
     norm=True,
 ):
     """
-    Run the Li and Stephens forwards algorithm on haplotype or
-    unphased genotype data.
+    Run the Li & Stephens forwards algorithm on haplotype or unphased genotype data.
     """
     n, m, ploidy = checks(
         reference_panel,
@@ -281,8 +287,7 @@ def backwards(
     scale_mutation_based_on_n_alleles=True,
 ):
     """
-    Run the Li and Stephens backwards algorithm on haplotype or
-    unphased genotype data.
+    Run the Li & Stephens backwards algorithm on haplotype or unphased genotype data.
     """
     n, m, ploidy = checks(
         reference_panel,
@@ -330,8 +335,7 @@ def viterbi(
     scale_mutation_based_on_n_alleles=True,
 ):
     """
-    Run the Li and Stephens Viterbi algorithm on haplotype or
-    unphased genotype data.
+    Run the Li & Stephens Viterbi algorithm on haplotype or unphased genotype data.
     """
     n, m, ploidy = checks(
         reference_panel,
diff --git a/lshmm/core.py b/lshmm/core.py
new file mode 100644
index 0000000..dd822a1
--- /dev/null
+++ b/lshmm/core.py
@@ -0,0 +1,74 @@
+import numpy as np
+
+from lshmm import jit
+
+
+EQUAL_BOTH_HOM = 4
+UNEQUAL_BOTH_HOM = 0
+BOTH_HET = 7
+REF_HOM_OBS_HET = 1
+REF_HET_OBS_HOM = 2
+MISSING_INDEX = 3
+
+MISSING = -1
+
+
+""" Helper functions. 
""" + + +# https://github.com/numba/numba/issues/1269 +@jit.numba_njit +def np_apply_along_axis(func1d, axis, arr): + """Create numpy-like functions for max, sum etc.""" + assert arr.ndim == 2 + assert axis in [0, 1] + if axis == 0: + result = np.empty(arr.shape[1]) + for i in range(len(result)): + result[i] = func1d(arr[:, i]) + else: + result = np.empty(arr.shape[0]) + for i in range(len(result)): + result[i] = func1d(arr[i, :]) + return result + + +@jit.numba_njit +def np_amax(array, axis): + """Numba implementation of numpy vectorised maximum.""" + return np_apply_along_axis(np.amax, axis, array) + + +@jit.numba_njit +def np_sum(array, axis): + """Numba implementation of numpy vectorised sum.""" + return np_apply_along_axis(np.sum, axis, array) + + +@jit.numba_njit +def np_argmax(array, axis): + """Numba implementation of numpy vectorised argmax.""" + return np_apply_along_axis(np.argmax, axis, array) + + +""" Functions used across different implementations of the LS HMM. """ + + +@jit.numba_njit +def get_index_in_emission_prob_matrix(ref_allele, query_allele): + is_allele_match = np.equal(ref_allele, query_allele) + is_query_missing = query_allele == MISSING + if is_allele_match or is_query_missing: + return 1 + return 0 + + +@jit.numba_njit +def get_index_in_emission_prob_matrix_diploid(ref_allele, query_allele): + if query_allele == MISSING: + return MISSING_INDEX + else: + is_allele_match = ref_allele == query_allele + is_ref_one = ref_allele == 1 + is_query_one = query_allele == 1 + return 4 * is_allele_match + 2 * is_ref_one + is_query_one diff --git a/lshmm/forward_backward/fb_diploid.py b/lshmm/forward_backward/fb_diploid.py index 50ffe12..334c816 100644 --- a/lshmm/forward_backward/fb_diploid.py +++ b/lshmm/forward_backward/fb_diploid.py @@ -1,64 +1,25 @@ -"""Collection of functions to run forwards and backwards algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens forwards-backwards algorithm on diploid genotype data, +where the data is structured as variants x samples x samples. 
+""" + import numpy as np +from lshmm import core from lshmm import jit -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - def forwards_ls_dip(n, m, G, s, e, r, norm=True): - """Matrix based diploid LS forward algorithm using numpy vectorisation.""" - # Initialise the forward tensor + """A matrix-based implementation using Numpy vectorisation.""" + # Initialise F = np.zeros((m, n, n)) F[0, :, :] = 1 / (n**2) c = np.ones(m) r_n = r / n - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones( + if s[0, 0] == core.MISSING: + index = core.MISSING_INDEX * np.ones( (n, n), dtype=np.int64 ) # We could have chosen anything here, this just implies a multiplication by a constant. else: @@ -76,8 +37,8 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( G[l, :, :] == 1 @@ -93,7 +54,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) # Emission @@ -105,8 +66,8 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): else: # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( G[l, :, :] == 1 @@ -122,7 +83,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 * np.sum(F[l - 1, :, :]) # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T # sum_j2 = np_sum(F[l - 1, :, :], 1).repeat(n).reshape((-1, n)) F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) @@ -135,18 +96,16 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): def backwards_ls_dip(n, m, G, s, e, c, r): - """Matrix based diploid LS backward algorithm using numpy vectorisation.""" - # Initialise the backward tensor - B = np.zeros((m, n, n)) - + """A matrix-based implementation using Numpy vectorisation.""" # Initialise + B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n # Backwards for l in range(m - 2, 
-1, -1): - if s[0, l + 1] == MISSING: - index = MISSING_INDEX * np.ones( + if s[0, l + 1] == core.MISSING: + index = core.MISSING_INDEX * np.ones( (n, n), dtype=np.int64 ) # We could have chosen anything here, this just implies a multiplication by a constant. else: @@ -167,7 +126,9 @@ def backwards_ls_dip(n, m, G, s, e, c, r): ) # One changes - sum_j = np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n)) + sum_j = ( + core.np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n)) + ) B[l, :, :] += ((1 - r[l + 1]) * r_n[l + 1]) * (sum_j + sum_j.T) B[l, :, :] *= 1 / c[l + 1] @@ -177,14 +138,15 @@ def backwards_ls_dip(n, m, G, s, e, c, r): @jit.numba_njit def forward_ls_dip_starting_point(n, m, G, s, e, r): """Naive implementation of LS diploid forwards algorithm.""" - # Initialise the forward tensor + # Initialise F = np.zeros((m, n, n)) r_n = r / n + for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -194,7 +156,6 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): F[0, j1, j2] *= e[0, index_tmp] for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j1_change = np.zeros(n) F_j2_change = np.zeros(n) @@ -231,24 +192,24 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): # What is the emission? - if s[0, l] == MISSING: - F[l, j1, j2] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, j1, j2] *= e[l, core.MISSING_INDEX] else: if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -257,16 +218,13 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def backward_ls_dip_starting_point(n, m, G, s, e, r): - """Naive implementation of LS diploid backwards algorithm.""" - # Backwards - B = np.zeros((m, n, n)) - + """A naive implementation.""" # Initialise + B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n for l in range(m - 2, -1, -1): - # Determine the various components B_no_change = np.zeros((n, n)) B_j1_change = np.zeros(n) B_j2_change = np.zeros(n) @@ -274,8 +232,8 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -283,18 +241,18 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, 
j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): @@ -336,13 +294,13 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): """LS diploid forwards algoritm without vectorisation.""" - # Initialise the forward tensor + # Initialise F = np.zeros((m, n, n)) for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -358,7 +316,6 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): F[0, :, :] *= 1 / c[0] for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j_change = np.zeros(n) @@ -375,8 +332,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -384,18 +341,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] c[l] = np.sum(F[l, :, :]) F[l, :, :] *= 1 / c[l] @@ -404,7 +361,6 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): else: for l in range(1, m): - # Determine the various components F_no_change = np.zeros((n, n)) F_j1_change = np.zeros(n) F_j2_change = np.zeros(n) @@ -425,8 +381,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -434,18 +390,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -455,13 +411,12 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): 
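 # Note: the backward passes below take the per-site normalisation constants c
 # computed by the matching forward pass, so B is rescaled by 1 / c[l + 1] in
 # the same way that F was rescaled on the forward sweep.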
@jit.numba_njit def backward_ls_dip_loop(n, m, G, s, e, c, r): """LS diploid backwards algoritm without vectorisation.""" - # Initialise the backward tensor + # Initialise B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 r_n = r / n for l in range(m - 2, -1, -1): - # Determine the various components B_no_change = np.zeros((n, n)) B_j1_change = np.zeros(n) B_j2_change = np.zeros(n) @@ -469,8 +424,8 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -479,18 +434,18 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): diff --git a/lshmm/forward_backward/fb_haploid.py b/lshmm/forward_backward/fb_haploid.py index 69d01fc..37f8dbf 100644 --- a/lshmm/forward_backward/fb_haploid.py +++ b/lshmm/forward_backward/fb_haploid.py @@ -1,25 +1,28 @@ -"""Collection of functions to run forwards and backwards algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens forwards-backwards algorithm on haploid genotype data, +where the data is structured as variants x samples. 
+""" import numpy as np +from lshmm import core from lshmm import jit -MISSING = -1 - @jit.numba_njit def forwards_ls_hap(n, m, H, s, e, r, norm=True): - """Matrix based haploid LS forward algorithm using numpy vectorisation.""" - # Initialise + """A matrix-based implementation using Numpy vectorisation.""" F = np.zeros((m, n)) r_n = r / n if norm: c = np.zeros(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], + query_allele=s[0, 0] ) + F[0, i] = 1 / n * e[0, emission_idx] c[0] += F[0, i] for i in range(n): @@ -29,9 +32,11 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[l, i], + query_allele=s[0, l] + ) + F[l, i] *= e[l, emission_idx] c[l] += F[l, i] for i in range(n): @@ -43,17 +48,21 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): c = np.ones(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], + query_allele=s[0, 0] ) + F[0, i] = 1 / n * e[0, emission_idx] # Forwards pass for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[l, i], + query_allele=s[0, l] + ) + F[l, i] *= e[l, emission_idx] ll = np.log10(np.sum(F[m - 1, :])) @@ -62,8 +71,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit def backwards_ls_hap(n, m, H, s, e, c, r): - """Matrix based haploid LS backward algorithm using numpy vectorisation.""" - # Initialise + """A matrix-based implementation using Numpy vectorisation.""" B = np.zeros((m, n)) for i in range(n): B[m - 1, i] = 1 @@ -74,15 +82,11 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - tmp_B[i] = ( - e[ - l + 1, - np.int64( - np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING - ), - ] - * B[l + 1, i] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[l + 1, i], + query_allele=s[0, l + 1] ) + tmp_B[i] = e[l + 1, emission_idx] * B[l + 1, i] tmp_B_sum += tmp_B[i] for i in range(n): B[l, i] = r_n[l + 1] * tmp_B_sum diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 316b5d0..e1c0304 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -1,72 +1,33 @@ -"""Collection of functions to run Viterbi algorithms on dipoid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens Viterbi algorithm on diploid genotype data, +where the data is structured as variants x samples x samples. +""" import numpy as np +from . import core from . 
import jit -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - @jit.numba_njit def forwards_viterbi_dip_naive(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm.""" + """A naive implementation.""" # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V[0, j1, j2] = 1 / (n**2) * e[0, emission_index] for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -101,31 +62,27 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm, with reduced memory.""" + """A naive implementation with reduced memory.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + V_prev = np.zeros((n, n)) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] - # Take a look at Haploid Viterbi implementation in Jeromes code and see if we can pinch some ideas. + # Take a look at the haploid Viterbi implementation in Jerome's code, and + # see if we can pinch some ideas. # Diploid Viterbi, with smaller memory footprint. 
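     # Staying on the same pair (j1, j2) can happen with no switch, with a
     # single switch that lands back on the same template, or with a double
     # switch that does so, hence the three-term weight below; pairs sharing
     # one index need at least one switch, and all other pairs need two.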
for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -138,7 +95,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): v = np.zeros((n, n)) for k1 in range(n): for k2 in range(n): - v[k1, k2] = V_previous[k1, k2] + v[k1, k2] = V_prev[k1, k2] if (k1 == j1) and (k2 == j2): v[k1, k2] *= ( (1 - r[l]) ** 2 + 2 * (1 - r[l]) * r_n[l] + r_n[l] ** 2 @@ -150,7 +107,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): V[j1, j2] = np.amax(v) * e[l, index[j1, j2]] P[l, j1, j2] = np.argmax(v) c[l] = np.amax(V) - V_previous = np.copy(V) / c[l] + V_prev = np.copy(V) / c[l] ll = np.sum(np.log10(c)) @@ -159,30 +116,25 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """An implementation with reduced memory.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + V_prev = np.zeros((n, n)) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -190,12 +142,12 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): + np.int64(s[0, l] == 1) ) - c[l] = np.amax(V_previous) - argmax = np.argmax(V_previous) + c[l] = np.amax(V_prev) + argmax = np.argmax(V_prev) - V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) - arg_rowcol_max = np_argmax(V_previous, 0) + V_prev *= 1 / c[l] + V_rowcol_max = core.np_amax(V_prev, 0) + arg_rowcol_max = core.np_argmax(V_prev, 0) no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 single_switch = r_n[l] * (1 - r[l]) + r_n[l] ** 2 @@ -215,7 +167,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): else: template_single_switch = arg_rowcol_max[j2] * n + j2 - V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either + V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either P[l, j1, j2] = j1_j2 # Single or double switch? 
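                 # A single switch keeps one of j1 or j2, so its best value is
                 # the larger of the precomputed row/column maxima, while a
                 # double switch can jump to the global argmax; only these
                 # rescaled candidates need to be compared against the
                 # no-switch value, rather than scanning all n**2 predecessors.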
@@ -233,7 +185,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): V[j1, j2] *= e[l, index[j1, j2]] j1_j2 += 1 - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) @@ -242,10 +194,10 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """An implementation with reduced memory and no pointer.""" # Initialise V = np.zeros((n, n)) - V_previous = np.zeros((n, n)) + V_prev = np.zeros((n, n)) c = np.ones(m) r_n = r / n @@ -262,20 +214,15 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V_prev[j1, j2] = 1 / (n**2) * e[0, emission_index] # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -283,14 +230,14 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): + np.int64(s[0, l] == 1) ) - c[l] = np.amax(V_previous) - argmax = np.argmax(V_previous) + c[l] = np.amax(V_prev) + argmax = np.argmax(V_prev) V_argmaxes[l - 1] = argmax # added - V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) + V_prev *= 1 / c[l] + V_rowcol_max = core.np_amax(V_prev, 0) V_rowcol_maxes[l - 1, :] = V_rowcol_max - arg_rowcol_max = np_argmax(V_previous, 0) + arg_rowcol_max = core.np_argmax(V_prev, 0) V_rowcol_argmaxes[l - 1, :] = arg_rowcol_max no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 @@ -302,7 +249,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): V_single_switch = max(V_rowcol_max[j1], V_rowcol_max[j2]) - V[j1, j2] = V_previous[j1, j2] * no_switch # No switch in either + V[j1, j2] = V_prev[j1, j2] * no_switch # No switch in either # Single or double switch? 
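                # This variant records V_argmaxes, V_rowcol_maxes, and
                # V_rowcol_argmaxes instead of a full (m, n, n) pointer matrix,
                # and backwards_viterbi_dip_no_pointer later rebuilds the best
                # path from these per-site maxima.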
single_switch_tmp = single_switch * V_single_switch @@ -319,11 +266,11 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): V[j1, j2] *= e[l, index[j1, j2]] j1_j2 += 1 - V_previous = np.copy(V) + V_prev = np.copy(V) - V_argmaxes[m - 1] = np.argmax(V_previous) - V_rowcol_maxes[m - 1, :] = np_amax(V_previous, 0) - V_rowcol_argmaxes[m - 1, :] = np_argmax(V_previous, 0) + V_argmaxes[m - 1] = np.argmax(V_prev) + V_rowcol_maxes[m - 1, :] = core.np_amax(V_prev, 0) + V_rowcol_argmaxes[m - 1, :] = core.np_argmax(V_prev, 0) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) return ( @@ -339,29 +286,24 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): - """Vectorised LS diploid Viterbi algorithm using numpy.""" + """An implementation using Numpy vectorisation.""" # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) r_n = r / n for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX - else: - index_tmp = ( - 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) - + 2 * np.int64((G[0, j1, j2] == 1)) - + np.int64(s[0, 0] == 1) - ) - V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp] + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, j1, j2], query_allele=s[0, 0] + ) + V[0, j1, j2] = 1 / (n**2) * e[0, emission_index] # Jumped the gun - vectorising. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -396,10 +338,11 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): # Initialise V = np.zeros((m, n, n)) - P = np.zeros((m, n, n)).astype(np.int64) + P = np.zeros((m, n, n), dtype=np.int64) c = np.ones(m) - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + + if s[0, 0] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) @@ -410,8 +353,8 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -439,8 +382,9 @@ def backwards_viterbi_dip(m, V_last, P): """Run a backwards pass to determine the most likely path.""" assert V_last.ndim == 2 assert V_last.shape[0] == V_last.shape[1] - # Initialisation - path = np.zeros(m).astype(np.int64) + + # Initialise + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) # Backtrace @@ -455,8 +399,7 @@ def in_list(array, value): where = np.searchsorted(array, value) if where < array.shape[0]: return array[where] == value - else: - return False + return False @jit.numba_njit @@ -472,8 +415,9 @@ def backwards_viterbi_dip_no_pointer( """Run a backwards pass to determine the most likely path.""" assert V_last.ndim == 2 assert V_last.shape[0] == V_last.shape[1] - # Initialisation - path = np.zeros(m).astype(np.int64) + + # Initialise + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) n = V_last.shape[0] @@ -503,21 +447,16 @@ def get_phased_path(n, path): @jit.numba_njit def path_ll_dip(n, m, G, 
phased_path, s, e, r): """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - if s[0, 0] == MISSING: - index = MISSING_INDEX - else: - index = ( - 4 * np.int64(np.equal(G[0, phased_path[0][0], phased_path[1][0]], s[0, 0])) - + 2 * np.int64(G[0, phased_path[0][0], phased_path[1][0]] == 1) - + np.int64(s[0, 0] == 1) - ) - log_prob_path = np.log10(1 / (n**2) * e[0, index]) + emission_index = core.get_index_in_emission_prob_matrix_diploid( + ref_allele=G[0, phased_path[0][0], phased_path[1][0]], query_allele=s[0, 0] + ) + log_prob_path = np.log10(1 / (n**2) * e[0, emission_index]) old_phase = np.array([phased_path[0][0], phased_path[1][0]]) r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX else: index = ( 4 diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 7fec45e..f0d537d 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -1,57 +1,60 @@ -"""Collection of functions to run Viterbi algorithms on haploid genotype data, where the data is structured as variants x samples.""" +""" +Various implementations of the Li & Stephens Viterbi algorithm on haploid genotype data, +where the data is structured as variants x samples. +""" import numpy as np +from . import core from . import jit -MISSING = -1 - @jit.numba_njit def viterbi_naive_init(n, m, H, s, e, r): - """Initialise naive implementation of LS viterbi.""" + """Initialise a naive implementation.""" V = np.zeros((m, n)) - P = np.zeros((m, n)).astype(np.int64) + P = np.zeros((m, n), dtype=np.int64) r_n = r / n + for i in range(n): - V[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + V[0, i] = 1 / n * e[0, emission_idx] return V, P, r_n @jit.numba_njit def viterbi_init(n, m, H, s, e, r): - """Initialise naive, but more space memory efficient implementation of LS viterbi.""" - V_previous = np.zeros(n) + """Initialise a naive, but more memory efficient, implementation.""" + V_prev = np.zeros(n) V = np.zeros(n) - P = np.zeros((m, n)).astype(np.int64) + P = np.zeros((m, n), dtype=np.int64) r_n = r / n for i in range(n): - V_previous[i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] ) + V_prev[i] = 1 / n * e[0, emission_idx] - return V, V_previous, P, r_n + return V, V_prev, P, r_n @jit.numba_njit def forwards_viterbi_hap_naive(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm.""" - # Initialise + """A naive implementation of the forward pass.""" V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V[j - 1, k] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V[j - 1, k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -66,8 +69,7 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): - """Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy.""" - # Initialise + """A naive matrix-based implementation of the forward pass using Numpy.""" V, P, r_n = 
viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): @@ -75,7 +77,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): for i in range(n): v = np.copy(v_tmp) v[i] += V[j - 1, i] * (1 - r[j]) - v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[j, i], query_allele=s[0, j] + ) + v *= e[j, emission_idx] P[j, i] = np.argmax(v) V[j, i] = v[P[j, i]] @@ -86,26 +91,24 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """A naive implementation of the forward pass with reduced memory.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) for j in range(1, m): for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V_prev[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: v[k] *= r_n[j] P[j, i] = np.argmax(v) V[i] = v[P[j, i]] - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.log10(np.amax(V)) @@ -114,30 +117,27 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """A naive implementation of the forward pass with reduced memory and rescaling.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) for j in range(1, m): - c[j] = np.amax(V_previous) - V_previous *= 1 / c[j] + c[j] = np.amax(V_prev) + V_prev *= 1 / c[j] for i in range(n): - # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[j, i], query_allele=s[0, j] ) + v[k] = V_prev[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: v[k] *= r_n[j] P[j, i] = np.argmax(v) V[i] = v[P[j, i]] - - V_previous = np.copy(V) + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) @@ -146,24 +146,26 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure.""" - # Initialise - V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) + """An implementation with reduced memory that exploits the Markov structure.""" + V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) for j in range(1, m): - argmax = np.argmax(V_previous) - c[j] = V_previous[argmax] - V_previous *= 1 / c[j] + argmax = np.argmax(V_prev) + c[j] = V_prev[argmax] + V_prev *= 1 / c[j] V = np.zeros(n) for i in range(n): - V[i] = V_previous[i] * (1 - r[j] + r_n[j]) + V[i] = V_prev[i] * (1 - r[j] + r_n[j]) P[j, i] = i if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - V_previous = np.copy(V) + emission_idx = core.get_index_in_emission_prob_matrix( + 
ref_allele=H[j, i], query_allele=s[0, j] + ) + V[i] *= e[j, emission_idx] + V_prev = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -172,12 +174,14 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """An implementation with even smaller memory footprint that exploits the Markov structure.""" V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - P = np.zeros((m, n)).astype(np.int64) + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] + ) + V[i] = 1 / n * e[0, emission_idx] + P = np.zeros((m, n), dtype=np.int64) r_n = r / n c = np.ones(m) @@ -191,7 +195,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[j, i], query_allele=s[0, j] + ) + V[i] *= e[j, emission_idx] ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -200,11 +207,13 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """An implementation with even smaller memory footprint and rescaling that exploits the Markov structure.""" V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, i], query_allele=s[0, 0] + ) + V[i] = 1 / n * e[0, emission_idx] r_n = r / n c = np.ones(m) recombs = [ @@ -225,7 +234,8 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): recombs[j] = np.append( recombs[j], i ) # We add template i as a potential template to recombine to at site j. 
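             # Keeping only these per-site candidate lists plus V_argmaxes
             # replaces the (m, n) pointer matrix P used by the other haploid
             # Viterbi implementations; backwards_viterbi_hap_no_pointer uses
             # them to trace back the best path.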
- V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + V[i] *= e[j, emission_idx] V_argmaxes[m - 1] = np.argmax(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -237,9 +247,8 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): @jit.numba_njit def backwards_viterbi_hap(m, V_last, P): """Run a backwards pass to determine the most likely path.""" - # Initialise assert len(V_last.shape) == 1 - path = np.zeros(m).astype(np.int64) + path = np.zeros(m, dtype=np.int64) path[m - 1] = np.argmax(V_last) for j in range(m - 2, -1, -1): @@ -251,8 +260,7 @@ def backwards_viterbi_hap(m, V_last, P): @jit.numba_njit def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): """Run a backwards pass to determine the most likely path.""" - # Initialise - path = np.zeros(m).astype(np.int64) + path = np.zeros(m, dtype=np.int64) path[m - 1] = V_argmaxes[m - 1] for j in range(m - 2, -1, -1): @@ -266,14 +274,18 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): @jit.numba_njit def path_ll_hap(n, m, H, path, s, e, r): - """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - index = np.int64(np.equal(H[0, path[0]], s[0, 0]) or s[0, 0] == MISSING) - log_prob_path = np.log10((1 / n) * e[0, index]) + """Evaluate the log-likelihood of a path through a reference panel resulting in a sequence.""" + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[0, path[0]], query_allele=s[0, 0] + ) + log_prob_path = np.log10((1 / n) * e[0, emission_idx]) old = path[0] r_n = r / n for l in range(1, m): - index = np.int64(np.equal(H[l, path[l]], s[0, l]) or s[0, l] == MISSING) + emission_idx = core.get_index_in_emission_prob_matrix( + ref_allele=H[l, path[l]], query_allele=s[0, l] + ) current = path[l] same = old == current @@ -282,7 +294,7 @@ def path_ll_hap(n, m, H, path, s, e, r): else: log_prob_path += np.log10(r_n[l]) - log_prob_path += np.log10(e[l, index]) + log_prob_path += np.log10(e[l, emission_idx]) old = current return log_prob_path diff --git a/tests/test_API.py b/tests/test_API.py index 129e67a..6b869ab 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -1,27 +1,18 @@ -# Simulation import itertools +import pytest -# Python libraries -import msprime import numpy as np -import pytest + +import msprime import tskit import lshmm as ls +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - class LSBase: """Superclass of Li and Stephens tests.""" @@ -33,13 +24,13 @@ def example_haplotypes(self, ts, seed=42): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -68,18 +59,18 @@ def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles def genotype_emission(self, mu, m): # Define the emission probability matrix e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, 
BOTH_HET] = (1 - mu) ** 2 + mu**2
-        e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu)
-        e[:, REF_HET_OBS_HOM] = mu * (1 - mu)
-        e[:, MISSING_INDEX] = 1
-
+        e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2
+        e[:, core.UNEQUAL_BOTH_HOM] = mu**2
+        e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2
+        e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu)
+        e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu)
+        e[:, core.MISSING_INDEX] = 1
         return e

     def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True):
-        """Returns an iterator over combinations of haplotype, recombination and
-        mutation probabilities."""
+        """
+        Returns an iterator over combinations of haplotype, recombination and mutation probabilities.
+        """
         np.random.seed(seed)
         H, haplotypes = self.example_haplotypes(ts)
         n = H.shape[1]

         def _get_num_alleles(ref_haps, query):
             assert ref_haps.shape[0] == query.shape[1]
             num_sites = ref_haps.shape[0]
             num_alleles = np.zeros(num_sites, dtype=np.int8)
-            exclusion_set = np.array([MISSING])
+            exclusion_set = np.array([core.MISSING])
             for i in range(num_sites):
                 uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i]))
                 num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set))
@@ -132,13 +123,13 @@ def example_genotypes(self, ts, seed=42):
         ]

         s_tmp = s.copy()
-        s_tmp[0, -1] = MISSING
+        s_tmp[0, -1] = core.MISSING
         genotypes.append(s_tmp)
         s_tmp = s.copy()
-        s_tmp[0, ts.num_sites // 2] = MISSING
+        s_tmp[0, ts.num_sites // 2] = core.MISSING
         genotypes.append(s_tmp)
         s_tmp = s.copy()
-        s_tmp[0, :] = MISSING
+        s_tmp[0, :] = core.MISSING
         genotypes.append(s_tmp)

         m = ts.get_num_sites()
@@ -189,17 +180,14 @@ def example_parameters_genotypes_larger(
         r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)
         r[0] = 0

-        # Error probability
         mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2)

-        # Define the emission probability matrix
         e = self.genotype_emission(mu, m)

         for s in genotypes:
             yield n, m, G, s, e, r, mu

     def assertAllClose(self, A, B):
-        """Assert that all entries of two matrices are 'close'"""
         assert np.allclose(A, B, rtol=1e-9, atol=0.0)

     # Define a bunch of very small tree-sequences for testing a collection of parameters on
@@ -240,7 +228,7 @@ class FBAlgorithmBase(LSBase):


 class TestMethodsHap(FBAlgorithmBase):
-    """Test that we compute the sample likelihoods across all implementations."""
+    """Test that the computed likelihood is the same across all implementations."""

     def verify(self, ts):
         for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
@@ -257,7 +245,7 @@


 class TestMethodsDip(FBAlgorithmBase):
-    """Test that we compute the sample likelihoods across all implementations."""
+    """Test that the computed likelihood is the same across all implementations."""

     def verify(self, ts):
         for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts):
@@ -273,11 +261,11 @@


 class VitAlgorithmBase(LSBase):
-    """Base for viterbi algoritm tests."""
+    """Base for Viterbi algorithm tests."""


 class TestViterbiHap(VitAlgorithmBase):
-    """Test that we have the same log-likelihood across all implementations"""
+    """Test that the computed log-likelihood is the same across all implementations."""

     def verify(self, ts):
         for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts):
@@ -292,7 +280,7 @@


 class TestViterbiDip(VitAlgorithmBase):
-    """Test that we have the same log-likelihood across all implementations"""
+    """Test that the computed log-likelihood is the same across all implementations."""

     def 
verify(self, ts): for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 92f1ab3..1add27e 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -1,27 +1,17 @@ -# Simulation import itertools -# Python libraries import msprime import numpy as np import pytest import tskit import lshmm as ls +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - class LSBase: """Superclass of Li and Stephens tests.""" @@ -33,13 +23,13 @@ def example_haplotypes(self, ts, num_random=10, seed=42): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -77,7 +67,7 @@ def _get_num_alleles(ref_haps, query): assert ref_haps.shape[0] == query.shape[1] num_sites = ref_haps.shape[0] num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) for i in range(num_sites): uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) @@ -115,7 +105,6 @@ def _get_num_alleles(ref_haps, query): yield n, m, H, s, e, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" assert np.allclose(A, B, rtol=1e-9, atol=0.0) # Define a bunch of very small tree-sequences for testing a collection of parameters on @@ -177,7 +166,7 @@ class FBAlgorithmBase(LSBase): class TestMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" + """Test that we compute the same likelihoods across all implementations.""" def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): @@ -187,7 +176,6 @@ def verify(self, ts): B = ls.backwards(H_vs, s, c, r, p_mutation=mu) self.assertAllClose(F, F_vs) self.assertAllClose(B, B_vs) - # print(e_vs) self.assertAllClose(ll_vs, ll) for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes( @@ -211,7 +199,7 @@ class VitAlgorithmBase(LSBase): class TestViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + """Test that we have the same log-likelihood across all implementations.""" def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): diff --git a/tests/test_LS_haploid_diploid.py b/tests/test_LS_haploid_diploid.py index 9b9f7d8..ae4ee9a 100644 --- a/tests/test_LS_haploid_diploid.py +++ b/tests/test_LS_haploid_diploid.py @@ -1,28 +1,18 @@ -# Simulation import itertools +import pytest -# Python libraries -import msprime import numpy as np -import pytest +import numba as nb + +import msprime +import tskit +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 
-BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - -import numba as nb -import tskit - class LSBase: """Superclass of Li and Stephens tests.""" @@ -34,13 +24,13 @@ def example_haplotypes(self, ts): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -50,19 +40,17 @@ def haplotype_emission(self, mu, m): e = np.zeros((m, 2)) e[:, 0] = mu # If they match e[:, 1] = 1 - mu # If they don't match - return e def genotype_emission(self, mu, m): # Define the emission probability matrix e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = 1 - mu - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - + e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = mu**2 + e[:, core.BOTH_HET] = 1 - mu + e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) + e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) + e[:, core.MISSING_INDEX] = 1 return e def example_parameters_haplotypes(self, ts, seed=42): @@ -124,13 +112,13 @@ def example_genotypes(self, ts, seed=42): ] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING genotypes.append(s_tmp) m = ts.get_num_sites()
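A minimal usage sketch of the new core helpers, assuming a toy reference
panel H, query s, and mutation rate mu (hypothetical values chosen for
illustration; the emission layout follows set_emission_probabilities above):

    import numpy as np

    import lshmm.core as core

    # Toy panel of m = 3 sites by n = 2 haplotypes, and one query haplotype
    # with a missing value at the last site.
    H = np.array([[0, 1], [1, 1], [0, 0]], dtype=np.int8)
    s = np.array([[0, 1, core.MISSING]], dtype=np.int8)

    # Haploid emission matrix: column 0 holds the mismatch probability and
    # column 1 the match probability.
    mu = 0.01
    e = np.zeros((3, 2))
    e[:, 0] = mu
    e[:, 1] = 1 - mu

    for l in range(3):
        for i in range(2):
            emission_idx = core.get_index_in_emission_prob_matrix(
                ref_allele=H[l, i], query_allele=s[0, l]
            )
            # A missing query allele always maps to index 1, so it scales
            # every template by the same constant at that site.
            prob = e[l, emission_idx]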