diff --git a/lshmm/api.py b/lshmm/api.py index f917ecf..6cfb663 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -4,6 +4,7 @@ import numpy as np +from . import core from .forward_backward.fb_diploid import backward_ls_dip_loop, forward_ls_dip_loop from .forward_backward.fb_haploid import backwards_ls_hap, forwards_ls_hap from .vit_diploid import ( @@ -18,15 +19,6 @@ path_ll_hap, ) -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 -MISSING_INDEX = 3 - -MISSING = -1 - def check_alleles(alleles, m): """ @@ -52,7 +44,7 @@ def check_alleles(alleles, m): if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) # Otherwise, process allele lists. - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(num_sites, dtype=np.int8) for i in range(num_sites): uniq_alleles = np.unique(alleles[i]) @@ -68,7 +60,7 @@ def checks( scale_mutation_based_on_n_alleles, ): """ - Checks that the input data and parameters are valid. + Check that the input data and parameters are valid. The reference panel must be a matrix of size (m, n) or (m, n, n). The query must be a matrix of size (k, m) or (k, m, 2). @@ -77,17 +69,21 @@ def checks( n = number of samples in the reference panel (haplotypes, not individuals). k = number of samples in the query (haplotypes, not individuals). + The mutation rate can be scaled according to the set of alleles + that can be mutated to based on the number of distinct alleles at each site. + :param numpy.ndarray(dtype=int) reference_panel: Matrix of size (m, n) or (m, n, n). :param numpy.ndarray(dtype=int) query: Matrix of size (k, m) or (k, m, 2). :param numpy.ndarray(dtype=float) p_mutation: Scalar or vector of length m. :param numpy.ndarray(dtype=float) p_recombination: Scalar or vector of length m. - :param bool scale_mutation_based_on_n_alleles: Whether to scale the mutation probability to the set of alleles that can be mutated to based on the number of alleles (True) or not (False). + :param bool scale_mutation_based_on_n_alleles: Scale the mutation probability or not. :return: n, m, ploidy :rtype: tuple """ # Check reference panel if not len(reference_panel.shape) in (2, 3): - raise ValueError("Reference panel array must have 2 or 3 dimensions.") + err_msg = "Reference panel array must have 2 or 3 dimensions." + raise ValueError(err_msg) if len(reference_panel.shape) == 2: m, n = reference_panel.shape @@ -97,42 +93,47 @@ def checks( ploidy = 2 if ploidy == 2 and (reference_panel.shape[1] != reference_panel.shape[2]): - raise ValueError( - "Reference_panel dimensions are incorrect, perhaps a sample x sample x variant matrix was passed. Expected sites x samples x samples." + err_msg = ( + "Reference_panel dimensions are incorrect, " + "perhaps a sample x sample x variant matrix was passed. " + "Expected sites x samples x samples." ) + raise ValueError(err_msg) # Check query sequence(s) if query.shape[1] != m: - raise ValueError( - "Number of sites in query does not match reference panel. If haploid, ensure a sites x samples matrix is passed." + err_msg = ( + "Number of sites in query does not match reference panel. " + "If haploid, ensure a sites x samples matrix is passed." ) + raise ValueError(err_msg) - # Ensure that the mutation probability is either a scalar or vector of length m + # Ensure that the mutation probability is either a scalar or vector of length m. if isinstance(p_mutation, (int, float)): if not scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a scalar probability of mutation, but not rescaling this probability of mutation conditional on the number of alleles at the site." - ) + warn_msg = "Passed a scalar mutation probability, but not rescaling it." + warnings.warn(warn_msg) elif isinstance(p_mutation, np.ndarray) and p_mutation.shape[0] == m: if scale_mutation_based_on_n_alleles: - warnings.warn( - "Passed a vector of probabilities of mutation, but rescaling each mutation probability conditional on the number of alleles at each site." - ) + warn_msg = "Passed a vector of mutation probabilities. Rescaling them." + warnings.warn(warn_msg) elif p_mutation is None: - warnings.warn( - "No mutation probability passed, setting mutation probability based on Li and Stephens 2003, equations (A2) and (A3)" + warn_msg = ( + "No mutation probability passed. " + "Setting it based on Li & Stephens (2003) equations A2 and A3." ) + warnings.warn(warn_msg) else: - raise ValueError( - f"Mutation probability is not None, a scalar, or vector of length m: {m}" - ) + err_msg = f"Mutation probability is not None, a scalar, or vector of length {m}." + raise ValueError(err_msg) # Ensure that the recombination probability is either a scalar or a vector of length m if not ( isinstance(p_recombination, (int, float)) or (isinstance(p_recombination, np.ndarray) and p_recombination.shape[0] == m) ): - raise ValueError(f"p_Recombination is not a scalar or vector of length m: {m}") + err_msg = f"Recombination probability is not a scalar or vector of length {m}." + raise ValueError(err_msg) return (n, m, ploidy) @@ -148,9 +149,10 @@ def set_emission_probabilities( scale_mutation_based_on_n_alleles, ): # Check alleles should go in here, and modify e before passing to the algorithm - # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. + # If alleles is not passed, we don't perform a test of alleles, + # but set n_alleles based on the reference_panel. if alleles is None: - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) n_alleles = np.zeros(m, dtype=np.int8) for j in range(reference_panel.shape[0]): uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) @@ -159,7 +161,7 @@ def set_emission_probabilities( n_alleles = check_alleles(alleles, m) if p_mutation is None: - # Set the mutation probability to be the proposed mutation probability in Li and Stephens (2003). + # Set the mutation probability to be the proposed mutation probability in Li & Stephens (2003). theta_tilde = 1 / np.sum([1 / k for k in range(1, n - 1)]) p_mutation = 0.5 * (theta_tilde / (n + theta_tilde)) @@ -172,15 +174,18 @@ def set_emission_probabilities( e = np.zeros((m, 2)) if scale_mutation_based_on_n_alleles: - # Scale mutation based on the number of alleles - so p_mutation is probability of mutation any given one of the alleles. + # Scale mutation based on the number of alleles, + # so p_mutation is probability of mutation any given one of the alleles. # The overall mutation probability is then (n_alleles - 1) * p_mutation. e[:, 0] = p_mutation - p_mutation * np.equal( n_alleles, np.ones(m) ) # Added boolean in case we're at an invariant site e[:, 1] = 1 - (n_alleles - 1) * p_mutation else: - # No scaling based on the number of alleles - so p_mutation is the probability of mutation to anything - # (summing over the states we can switch to). This means that we must rescale the probability of mutation to + # No scaling based on the number of alleles, + # so p_mutation is the probability of mutation to anything + # (summing over the states we can switch to). + # This means that we must rescale the probability of mutation to # a different allele by the number of alleles at the site. for j in range(m): if n_alleles[j] == 1: # In case we're at an invariant site @@ -194,12 +199,12 @@ def set_emission_probabilities( # Evaluate emission probabilities here, using the mutation probability - this can take a scalar or vector. # DEV: there's a wrinkle here. e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2 - e[:, UNEQUAL_BOTH_HOM] = p_mutation**2 - e[:, BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation**2 - e[:, REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation) - e[:, REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation) - e[:, MISSING_INDEX] = 1 + e[:, core.EQUAL_BOTH_HOM] = (1 - p_mutation) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = p_mutation ** 2 + e[:, core.BOTH_HET] = (1 - p_mutation) ** 2 + p_mutation ** 2 + e[:, core.REF_HOM_OBS_HET] = 2 * p_mutation * (1 - p_mutation) + e[:, core.REF_HET_OBS_HOM] = p_mutation * (1 - p_mutation) + e[:, core.MISSING_INDEX] = 1 return e @@ -233,8 +238,7 @@ def forwards( norm=True, ): """ - Run the Li and Stephens forwards algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens forwards algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, @@ -281,8 +285,7 @@ def backwards( scale_mutation_based_on_n_alleles=True, ): """ - Run the Li and Stephens backwards algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens backwards algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, @@ -330,8 +333,7 @@ def viterbi( scale_mutation_based_on_n_alleles=True, ): """ - Run the Li and Stephens Viterbi algorithm on haplotype or - unphased genotype data. + Run the Li & Stephens Viterbi algorithm on haplotype or unphased genotype data. """ n, m, ploidy = checks( reference_panel, diff --git a/lshmm/core.py b/lshmm/core.py new file mode 100644 index 0000000..e4a701c --- /dev/null +++ b/lshmm/core.py @@ -0,0 +1,54 @@ +import numpy as np + +from lshmm import jit + + +EQUAL_BOTH_HOM = 4 +UNEQUAL_BOTH_HOM = 0 +BOTH_HET = 7 +REF_HOM_OBS_HET = 1 +REF_HET_OBS_HOM = 2 +MISSING_INDEX = 3 + +MISSING = -1 + + +""" Helper functions. """ +# https://github.com/numba/numba/issues/1269 +@jit.numba_njit +def np_apply_along_axis(func1d, axis, arr): + """ Create numpy-like functions for max, sum etc. """ + assert arr.ndim == 2 + assert axis in [0, 1] + if axis == 0: + result = np.empty(arr.shape[1]) + for i in range(len(result)): + result[i] = func1d(arr[:, i]) + else: + result = np.empty(arr.shape[0]) + for i in range(len(result)): + result[i] = func1d(arr[i, :]) + return result + +@jit.numba_njit +def np_amax(array, axis): + """Numba implementation of numpy vectorised maximum.""" + return np_apply_along_axis(np.amax, axis, array) + + +@jit.numba_njit +def np_sum(array, axis): + """Numba implementation of numpy vectorised sum.""" + return np_apply_along_axis(np.sum, axis, array) + + +@jit.numba_njit +def np_argmax(array, axis): + """ Numba implementation of numpy vectorised argmax. """ + return np_apply_along_axis(np.argmax, axis, array) + + +""" Functions used across different implementations of the LS HMM. """ +@jit.numba_njit +def get_index_in_emission_prob_matrix(a1, a2): + return np.int64(np.equal(a1, a2) or a2 == MISSING) diff --git a/lshmm/forward_backward/fb_diploid.py b/lshmm/forward_backward/fb_diploid.py index 50ffe12..45c033e 100644 --- a/lshmm/forward_backward/fb_diploid.py +++ b/lshmm/forward_backward/fb_diploid.py @@ -2,52 +2,9 @@ import numpy as np +from lshmm import core from lshmm import jit -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - def forwards_ls_dip(n, m, G, s, e, r, norm=True): """Matrix based diploid LS forward algorithm using numpy vectorisation.""" @@ -57,8 +14,8 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): c = np.ones(m) r_n = r / n - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones( + if s[0, 0] == core.MISSING: + index = core.MISSING_INDEX * np.ones( (n, n), dtype=np.int64 ) # We could have chosen anything here, this just implies a multiplication by a constant. else: @@ -76,8 +33,8 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( G[l, :, :] == 1 @@ -93,7 +50,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) # Emission @@ -105,8 +62,8 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): else: # Forwards for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) + 2 * ( G[l, :, :] == 1 @@ -122,7 +79,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): F[l, :, :] += (r_n[l]) ** 2 * np.sum(F[l - 1, :, :]) # One changes - sum_j = np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T + sum_j = core.np_sum(F[l - 1, :, :], 0).repeat(n).reshape((-1, n)).T # sum_j2 = np_sum(F[l - 1, :, :], 1).repeat(n).reshape((-1, n)) F[l, :, :] += ((1 - r[l]) * r_n[l]) * (sum_j + sum_j.T) @@ -135,7 +92,7 @@ def forwards_ls_dip(n, m, G, s, e, r, norm=True): def backwards_ls_dip(n, m, G, s, e, c, r): - """Matrix based diploid LS backward algorithm using numpy vectorisation.""" + """ Matrix based diploid LS backward algorithm using numpy vectorisation. """ # Initialise the backward tensor B = np.zeros((m, n, n)) @@ -145,8 +102,8 @@ def backwards_ls_dip(n, m, G, s, e, c, r): # Backwards for l in range(m - 2, -1, -1): - if s[0, l + 1] == MISSING: - index = MISSING_INDEX * np.ones( + if s[0, l + 1] == core.MISSING: + index = core.MISSING_INDEX * np.ones( (n, n), dtype=np.int64 ) # We could have chosen anything here, this just implies a multiplication by a constant. else: @@ -167,7 +124,7 @@ def backwards_ls_dip(n, m, G, s, e, c, r): ) # One changes - sum_j = np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n)) + sum_j = core.np_sum(B[l + 1, :, :] * e[l + 1, index], 0).repeat(n).reshape((-1, n)) B[l, :, :] += ((1 - r[l + 1]) * r_n[l + 1]) * (sum_j + sum_j.T) B[l, :, :] *= 1 / c[l + 1] @@ -176,15 +133,15 @@ def backwards_ls_dip(n, m, G, s, e, c, r): @jit.numba_njit def forward_ls_dip_starting_point(n, m, G, s, e, r): - """Naive implementation of LS diploid forwards algorithm.""" + """ Naive implementation of LS diploid forwards algorithm. """ # Initialise the forward tensor F = np.zeros((m, n, n)) r_n = r / n for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -231,24 +188,24 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): # What is the emission? - if s[0, l] == MISSING: - F[l, j1, j2] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, j1, j2] *= e[l, core.MISSING_INDEX] else: if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -257,7 +214,7 @@ def forward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def backward_ls_dip_starting_point(n, m, G, s, e, r): - """Naive implementation of LS diploid backwards algorithm.""" + """ Naive implementation of LS diploid backwards algorithm. """ # Backwards B = np.zeros((m, n, n)) @@ -274,8 +231,8 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -283,18 +240,18 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): @@ -335,14 +292,14 @@ def backward_ls_dip_starting_point(n, m, G, s, e, r): @jit.numba_njit def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): - """LS diploid forwards algoritm without vectorisation.""" + """ LS diploid forwards algoritm without vectorisation. """ # Initialise the forward tensor F = np.zeros((m, n, n)) for j1 in range(n): for j2 in range(n): F[0, j1, j2] = 1 / (n**2) - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -375,8 +332,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -384,18 +341,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] c[l] = np.sum(F[l, :, :]) F[l, :, :] *= 1 / c[l] @@ -425,8 +382,8 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): for j2 in range(n): F[l, j1, j2] += F_no_change[j1, j2] - if s[0, l] == MISSING: - F[l, :, :] *= e[l, MISSING_INDEX] + if s[0, l] == core.MISSING: + F[l, :, :] *= e[l, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -434,18 +391,18 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): if s[0, l] == 1: # OBS is het if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, BOTH_HET] + F[l, j1, j2] *= e[l, core.BOTH_HET] else: # REF is hom - F[l, j1, j2] *= e[l, REF_HOM_OBS_HET] + F[l, j1, j2] *= e[l, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l, j1, j2] == 1: # REF is het - F[l, j1, j2] *= e[l, REF_HET_OBS_HOM] + F[l, j1, j2] *= e[l, core.REF_HET_OBS_HOM] else: # REF is hom if G[l, j1, j2] == s[0, l]: # Equal - F[l, j1, j2] *= e[l, EQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.EQUAL_BOTH_HOM] else: # Unequal - F[l, j1, j2] *= e[l, UNEQUAL_BOTH_HOM] + F[l, j1, j2] *= e[l, core.UNEQUAL_BOTH_HOM] ll = np.log10(np.sum(F[l, :, :])) @@ -454,7 +411,7 @@ def forward_ls_dip_loop(n, m, G, s, e, r, norm=True): @jit.numba_njit def backward_ls_dip_loop(n, m, G, s, e, c, r): - """LS diploid backwards algoritm without vectorisation.""" + """ LS diploid backwards algoritm without vectorisation. """ # Initialise the backward tensor B = np.zeros((m, n, n)) B[m - 1, :, :] = 1 @@ -469,8 +426,8 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): # Evaluate the emission matrix at this site, for all pairs e_tmp = np.zeros((n, n)) - if s[0, l + 1] == MISSING: - e_tmp[:, :] = e[l + 1, MISSING_INDEX] + if s[0, l + 1] == core.MISSING: + e_tmp[:, :] = e[l + 1, core.MISSING_INDEX] else: for j1 in range(n): for j2 in range(n): @@ -479,18 +436,18 @@ def backward_ls_dip_loop(n, m, G, s, e, c, r): if s[0, l + 1] == 1: # OBS is het if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, BOTH_HET] + e_tmp[j1, j2] = e[l + 1, core.BOTH_HET] else: # REF is hom - e_tmp[j1, j2] = e[l + 1, REF_HOM_OBS_HET] + e_tmp[j1, j2] = e[l + 1, core.REF_HOM_OBS_HET] else: # OBS is hom if G[l + 1, j1, j2] == 1: # REF is het - e_tmp[j1, j2] = e[l + 1, REF_HET_OBS_HOM] + e_tmp[j1, j2] = e[l + 1, core.REF_HET_OBS_HOM] else: # REF is hom if G[l + 1, j1, j2] == s[0, l + 1]: # Equal - e_tmp[j1, j2] = e[l + 1, EQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.EQUAL_BOTH_HOM] else: # Unequal - e_tmp[j1, j2] = e[l + 1, UNEQUAL_BOTH_HOM] + e_tmp[j1, j2] = e[l + 1, core.UNEQUAL_BOTH_HOM] for j1 in range(n): for j2 in range(n): diff --git a/lshmm/forward_backward/fb_haploid.py b/lshmm/forward_backward/fb_haploid.py index 69d01fc..b3c111f 100644 --- a/lshmm/forward_backward/fb_haploid.py +++ b/lshmm/forward_backward/fb_haploid.py @@ -2,24 +2,21 @@ import numpy as np +from lshmm import core from lshmm import jit -MISSING = -1 - @jit.numba_njit def forwards_ls_hap(n, m, H, s, e, r, norm=True): - """Matrix based haploid LS forward algorithm using numpy vectorisation.""" - # Initialise + """ Matrix based haploid LS forward algorithm using numpy vectorisation. """ F = np.zeros((m, n)) r_n = r / n if norm: c = np.zeros(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + F[0, i] = 1 / n * e[0, emission_idx] c[0] += F[0, i] for i in range(n): @@ -29,9 +26,8 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_idx = core.get_index_in_emission_prob_matrix(H[l, i], s[0, l]) + F[l, i] *= e[l, emission_idx] c[l] += F[l, i] for i in range(n): @@ -43,17 +39,15 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): c = np.ones(m) for i in range(n): - F[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + F[0, i] = 1 / n * e[0, emission_idx] # Forwards pass for l in range(1, m): for i in range(n): F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l] - F[l, i] *= e[ - l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING) - ] + emission_idx = core.get_index_in_emission_prob_matrix(H[l, i], s[0, l]) + F[l, i] *= e[l, emission_idx] ll = np.log10(np.sum(F[m - 1, :])) @@ -62,8 +56,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True): @jit.numba_njit def backwards_ls_hap(n, m, H, s, e, c, r): - """Matrix based haploid LS backward algorithm using numpy vectorisation.""" - # Initialise + """ Matrix based haploid LS backward algorithm using numpy vectorisation. """ B = np.zeros((m, n)) for i in range(n): B[m - 1, i] = 1 @@ -74,15 +67,8 @@ def backwards_ls_hap(n, m, H, s, e, c, r): tmp_B = np.zeros(n) tmp_B_sum = 0 for i in range(n): - tmp_B[i] = ( - e[ - l + 1, - np.int64( - np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING - ), - ] - * B[l + 1, i] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[l + 1, i], s[0, l + 1]) + tmp_B[i] = e[l + 1, emission_idx] * B[l + 1, i] tmp_B_sum += tmp_B[i] for i in range(n): B[l, i] = r_n[l + 1] * tmp_B_sum diff --git a/lshmm/vit_diploid.py b/lshmm/vit_diploid.py index 316b5d0..d82e6ca 100644 --- a/lshmm/vit_diploid.py +++ b/lshmm/vit_diploid.py @@ -2,50 +2,13 @@ import numpy as np +from . import core from . import jit -MISSING = -1 -MISSING_INDEX = 3 - - -# https://github.com/numba/numba/issues/1269 -@jit.numba_njit -def np_apply_along_axis(func1d, axis, arr): - """Create numpy-like functions for max, sum etc.""" - assert arr.ndim == 2 - assert axis in [0, 1] - if axis == 0: - result = np.empty(arr.shape[1]) - for i in range(len(result)): - result[i] = func1d(arr[:, i]) - else: - result = np.empty(arr.shape[0]) - for i in range(len(result)): - result[i] = func1d(arr[i, :]) - return result - - -@jit.numba_njit -def np_amax(array, axis): - """Numba implementation of numpy vectorised maximum.""" - return np_apply_along_axis(np.amax, axis, array) - - -@jit.numba_njit -def np_sum(array, axis): - """Numba implementation of numpy vectorised sum.""" - return np_apply_along_axis(np.sum, axis, array) - - -@jit.numba_njit -def np_argmax(array, axis): - """Numba implementation of numpy vectorised argmax.""" - return np_apply_along_axis(np.argmax, axis, array) - @jit.numba_njit def forwards_viterbi_dip_naive(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm.""" + """ Naive implementation of LS diploid Viterbi algorithm. """ # Initialise V = np.zeros((m, n, n)) P = np.zeros((m, n, n)).astype(np.int64) @@ -54,8 +17,8 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -65,8 +28,8 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): V[0, j1, j2] = 1 / (n**2) * e[0, index_tmp] for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -101,7 +64,7 @@ def forwards_viterbi_dip_naive(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): - """Naive implementation of LS diploid Viterbi algorithm, with reduced memory.""" + """ Naive implementation of LS diploid Viterbi algorithm, with reduced memory. """ # Initialise V = np.zeros((n, n)) V_previous = np.zeros((n, n)) @@ -111,8 +74,8 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -121,11 +84,12 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): ) V_previous[j1, j2] = 1 / (n**2) * e[0, index_tmp] - # Take a look at Haploid Viterbi implementation in Jeromes code and see if we can pinch some ideas. + # Take a look at the haploid Viterbi implementation in Jerome's code, and + # see if we can pinch some ideas. # Diploid Viterbi, with smaller memory footprint. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -159,7 +123,7 @@ def forwards_viterbi_dip_naive_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """ LS diploid Viterbi algorithm, with reduced memory. """ # Initialise V = np.zeros((n, n)) V_previous = np.zeros((n, n)) @@ -169,8 +133,8 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -181,8 +145,8 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -194,8 +158,8 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): argmax = np.argmax(V_previous) V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) - arg_rowcol_max = np_argmax(V_previous, 0) + V_rowcol_max = core.np_amax(V_previous, 0) + arg_rowcol_max = core.np_argmax(V_previous, 0) no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 single_switch = r_n[l] * (1 - r[l]) + r_n[l] ** 2 @@ -242,7 +206,7 @@ def forwards_viterbi_dip_low_mem(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): - """LS diploid Viterbi algorithm, with reduced memory.""" + """ LS diploid Viterbi algorithm, with reduced memory. """ # Initialise V = np.zeros((n, n)) V_previous = np.zeros((n, n)) @@ -262,8 +226,8 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -274,8 +238,8 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): # Diploid Viterbi, with smaller memory footprint, rescaling, and using the structure of the HMM. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -288,9 +252,9 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): V_argmaxes[l - 1] = argmax # added V_previous *= 1 / c[l] - V_rowcol_max = np_amax(V_previous, 0) + V_rowcol_max = core.np_amax(V_previous, 0) V_rowcol_maxes[l - 1, :] = V_rowcol_max - arg_rowcol_max = np_argmax(V_previous, 0) + arg_rowcol_max = core.np_argmax(V_previous, 0) V_rowcol_argmaxes[l - 1, :] = arg_rowcol_max no_switch = (1 - r[l]) ** 2 + 2 * (r_n[l] * (1 - r[l])) + r_n[l] ** 2 @@ -322,8 +286,8 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): V_previous = np.copy(V) V_argmaxes[m - 1] = np.argmax(V_previous) - V_rowcol_maxes[m - 1, :] = np_amax(V_previous, 0) - V_rowcol_argmaxes[m - 1, :] = np_argmax(V_previous, 0) + V_rowcol_maxes[m - 1, :] = core.np_amax(V_previous, 0) + V_rowcol_argmaxes[m - 1, :] = core.np_argmax(V_previous, 0) ll = np.sum(np.log10(c)) + np.log10(np.amax(V)) return ( @@ -339,7 +303,7 @@ def forwards_viterbi_dip_low_mem_no_pointer(n, m, G, s, e, r): @jit.numba_njit def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): - """Vectorised LS diploid Viterbi algorithm using numpy.""" + """ Vectorised LS diploid Viterbi algorithm using numpy. """ # Initialise V = np.zeros((m, n, n)) P = np.zeros((m, n, n)).astype(np.int64) @@ -348,8 +312,8 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): for j1 in range(n): for j2 in range(n): - if s[0, 0] == MISSING: - index_tmp = MISSING_INDEX + if s[0, 0] == core.MISSING: + index_tmp = core.MISSING_INDEX else: index_tmp = ( 4 * np.int64(np.equal(G[0, j1, j2], s[0, 0])) @@ -360,8 +324,8 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): # Jumped the gun - vectorising. for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -388,7 +352,7 @@ def forwards_viterbi_dip_naive_vec(n, m, G, s, e, r): def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): - """Fully vectorised naive LS diploid Viterbi algorithm using numpy.""" + """ Fully vectorised naive LS diploid Viterbi algorithm using numpy. """ char_both = np.eye(n * n).ravel().reshape((n, n, n, n)) char_col = np.tile(np.sum(np.eye(n * n).reshape((n, n, n, n)), 3), (n, 1, 1, 1)) char_row = np.copy(char_col).T @@ -398,8 +362,8 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): V = np.zeros((m, n, n)) P = np.zeros((m, n, n)).astype(np.int64) c = np.ones(m) - if s[0, 0] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, 0] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[0, :, :], s[0, 0]).astype(np.int64) @@ -410,8 +374,8 @@ def forwards_viterbi_dip_naive_full_vec(n, m, G, s, e, r): r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX * np.ones((n, n), dtype=np.int64) + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX * np.ones((n, n), dtype=np.int64) else: index = ( 4 * np.equal(G[l, :, :], s[0, l]).astype(np.int64) @@ -455,8 +419,7 @@ def in_list(array, value): where = np.searchsorted(array, value) if where < array.shape[0]: return array[where] == value - else: - return False + return False @jit.numba_njit @@ -469,7 +432,7 @@ def backwards_viterbi_dip_no_pointer( recombs_double, V_last, ): - """Run a backwards pass to determine the most likely path.""" + """ Run a backwards pass to determine the most likely path. """ assert V_last.ndim == 2 assert V_last.shape[0] == V_last.shape[1] # Initialisation @@ -503,8 +466,8 @@ def get_phased_path(n, path): @jit.numba_njit def path_ll_dip(n, m, G, phased_path, s, e, r): """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - if s[0, 0] == MISSING: - index = MISSING_INDEX + if s[0, 0] == core.MISSING: + index = core.MISSING_INDEX else: index = ( 4 * np.int64(np.equal(G[0, phased_path[0][0], phased_path[1][0]], s[0, 0])) @@ -516,8 +479,8 @@ def path_ll_dip(n, m, G, phased_path, s, e, r): r_n = r / n for l in range(1, m): - if s[0, l] == MISSING: - index = MISSING_INDEX + if s[0, l] == core.MISSING: + index = core.MISSING_INDEX else: index = ( 4 diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py index 7fec45e..08a6e47 100644 --- a/lshmm/vit_haploid.py +++ b/lshmm/vit_haploid.py @@ -2,45 +2,42 @@ import numpy as np +from . import core from . import jit -MISSING = -1 - @jit.numba_njit def viterbi_naive_init(n, m, H, s, e, r): - """Initialise naive implementation of LS viterbi.""" + """ Initialise a naive implementation of the LS Viterbi algorithm. """ V = np.zeros((m, n)) P = np.zeros((m, n)).astype(np.int64) r_n = r / n + for i in range(n): - V[0, i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + V[0, i] = 1 / n * e[0, emission_idx] return V, P, r_n @jit.numba_njit def viterbi_init(n, m, H, s, e, r): - """Initialise naive, but more space memory efficient implementation of LS viterbi.""" + """ Initialise a naive, but more space memory efficient, implementation of the LS Viterbi algorithm. """ V_previous = np.zeros(n) V = np.zeros(n) P = np.zeros((m, n)).astype(np.int64) r_n = r / n for i in range(n): - V_previous[i] = ( - 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + V_previous[i] = 1 / n * e[0, emission_idx] return V, V_previous, P, r_n @jit.numba_njit def forwards_viterbi_hap_naive(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm.""" - # Initialise + """ Naive implementation of the haploid LS Viterbi algorithm. """ V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): @@ -48,10 +45,8 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V[j - 1, k] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + v[k] = e[j, emission_idx] * V[j - 1, k] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -66,8 +61,7 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): - """Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy.""" - # Initialise + """ Naive matrix based implementation of LS haploid forward Viterbi algorithm using numpy. """ V, P, r_n = viterbi_naive_init(n, m, H, s, e, r) for j in range(1, m): @@ -75,7 +69,8 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): for i in range(n): v = np.copy(v_tmp) v[i] += V[j - 1, i] * (1 - r[j]) - v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + v *= e[j, emission_idx] P[j, i] = np.argmax(v) V[j, i] = v[P[j, i]] @@ -86,8 +81,7 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory.""" - # Initialise + """ Naive implementation of LS haploid Viterbi algorithm, with reduced memory. """ V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) for j in range(1, m): @@ -95,10 +89,8 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + v[k] = V_previous[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -114,8 +106,7 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): - """Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling.""" - # Initialise + """ Naive implementation of LS haploid Viterbi algorithm, with reduced memory and rescaling. """ V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) @@ -126,10 +117,8 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): # Get the vector to maximise over v = np.zeros(n) for k in range(n): - v[k] = ( - e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] - * V_previous[k] - ) + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + v[k] = V_previous[k] * e[j, emission_idx] if k == i: v[k] *= 1 - r[j] + r_n[j] else: @@ -146,8 +135,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure.""" - # Initialise + """ LS haploid Viterbi algorithm, with reduced memory and exploits the Markov process structure. """ V, V_previous, P, r_n = viterbi_init(n, m, H, s, e, r) c = np.ones(m) @@ -162,7 +150,8 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + V[i] *= e[j, emission_idx] V_previous = np.copy(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -172,11 +161,11 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """ LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure. """ V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + V[i] = 1 / n * e[0, emission_idx] P = np.zeros((m, n)).astype(np.int64) r_n = r / n c = np.ones(m) @@ -191,7 +180,8 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): if V[i] < r_n[j]: V[i] = r_n[j] P[j, i] = argmax - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + V[i] *= e[j, emission_idx] ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -200,11 +190,11 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r): @jit.numba_njit def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): - """LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure.""" - # Initialise + """ LS haploid Viterbi algorithm with even smaller memory footprint and exploits the Markov process structure. """ V = np.zeros(n) for i in range(n): - V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[0, i], s[0, 0]) + V[i] = 1 / n * e[0, emission_idx] r_n = r / n c = np.ones(m) recombs = [ @@ -225,7 +215,8 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): recombs[j] = np.append( recombs[j], i ) # We add template i as a potential template to recombine to at site j. - V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)] + emission_idx = core.get_index_in_emission_prob_matrix(H[j, i], s[0, j]) + V[i] *= e[j, emission_idx] V_argmaxes[m - 1] = np.argmax(V) ll = np.sum(np.log10(c)) + np.log10(np.max(V)) @@ -236,8 +227,7 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r): # Speedier version, variants x samples @jit.numba_njit def backwards_viterbi_hap(m, V_last, P): - """Run a backwards pass to determine the most likely path.""" - # Initialise + """ Run a backwards pass to determine the most likely path. """ assert len(V_last.shape) == 1 path = np.zeros(m).astype(np.int64) path[m - 1] = np.argmax(V_last) @@ -250,8 +240,7 @@ def backwards_viterbi_hap(m, V_last, P): @jit.numba_njit def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): - """Run a backwards pass to determine the most likely path.""" - # Initialise + """ Run a backwards pass to determine the most likely path. """ path = np.zeros(m).astype(np.int64) path[m - 1] = V_argmaxes[m - 1] @@ -266,14 +255,14 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs): @jit.numba_njit def path_ll_hap(n, m, H, path, s, e, r): - """Evaluate log-likelihood path through a reference panel which results in sequence s.""" - index = np.int64(np.equal(H[0, path[0]], s[0, 0]) or s[0, 0] == MISSING) - log_prob_path = np.log10((1 / n) * e[0, index]) + """ Evaluate log-likelihood path through a reference panel which results in sequence s. """ + emission_idx = core.get_index_in_emission_prob_matrix(H[0, path[0]], s[0, 0]) + log_prob_path = np.log10((1 / n) * e[0, emission_idx]) old = path[0] r_n = r / n for l in range(1, m): - index = np.int64(np.equal(H[l, path[l]], s[0, l]) or s[0, l] == MISSING) + emission_idx = core.get_index_in_emission_prob_matrix(H[l, path[l]], s[0, l]) current = path[l] same = old == current @@ -282,7 +271,7 @@ def path_ll_hap(n, m, H, path, s, e, r): else: log_prob_path += np.log10(r_n[l]) - log_prob_path += np.log10(e[l, index]) + log_prob_path += np.log10(e[l, emission_idx]) old = current return log_prob_path diff --git a/tests/test_API.py b/tests/test_API.py index 129e67a..cdf5595 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -1,27 +1,18 @@ -# Simulation import itertools +import pytest -# Python libraries -import msprime import numpy as np -import pytest + +import msprime import tskit import lshmm as ls +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - class LSBase: """Superclass of Li and Stephens tests.""" @@ -33,13 +24,13 @@ def example_haplotypes(self, ts, seed=42): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -68,18 +59,18 @@ def haplotype_emission(self, mu, m, n_alleles, scale_mutation_based_on_n_alleles def genotype_emission(self, mu, m): # Define the emission probability matrix e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = (1 - mu) ** 2 + mu**2 - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - + e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = mu**2 + e[:, core.BOTH_HET] = (1 - mu) ** 2 + mu**2 + e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) + e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) + e[:, core.MISSING_INDEX] = 1 return e def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): - """Returns an iterator over combinations of haplotype, recombination and - mutation probabilities.""" + """ + Returns an iterator over combinations of haplotype, recombination and mutation probabilities. + """ np.random.seed(seed) H, haplotypes = self.example_haplotypes(ts) n = H.shape[1] @@ -89,7 +80,7 @@ def _get_num_alleles(ref_haps, query): assert ref_haps.shape[0] == query.shape[1] num_sites = ref_haps.shape[0] num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) for i in range(num_sites): uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) @@ -132,13 +123,13 @@ def example_genotypes(self, ts, seed=42): ] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING genotypes.append(s_tmp) m = ts.get_num_sites() @@ -189,17 +180,14 @@ def example_parameters_genotypes_larger( r = mean_r * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) r[0] = 0 - # Error probability mu = mean_mu * np.ones(m) * ((np.random.rand(m) + 0.5) / 2) - # Define the emission probability matrix e = self.genotype_emission(mu, m) for s in genotypes: yield n, m, G, s, e, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" assert np.allclose(A, B, rtol=1e-9, atol=0.0) # Define a bunch of very small tree-sequences for testing a collection of parameters on @@ -240,7 +228,7 @@ class FBAlgorithmBase(LSBase): class TestMethodsHap(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" + """ Test that the computed likelihood is the same across all implementations. """ def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): @@ -257,7 +245,7 @@ def verify(self, ts): class TestMethodsDip(FBAlgorithmBase): - """Test that we compute the sample likelihoods across all implementations.""" + """ Test that the computed likelihood is the same across all implementations. """ def verify(self, ts): for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): @@ -273,11 +261,11 @@ def verify(self, ts): class VitAlgorithmBase(LSBase): - """Base for viterbi algoritm tests.""" + """ Base for Viterbi algoritm tests. """ class TestViterbiHap(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + """ Test that the computed log-likelihood is the same across all implementations. """ def verify(self, ts): for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes(ts): @@ -292,7 +280,7 @@ def verify(self, ts): class TestViterbiDip(VitAlgorithmBase): - """Test that we have the same log-likelihood across all implementations""" + """ Test that the computed log-likelihood is the same across all implementations. """ def verify(self, ts): for n, m, G_vs, s, e_vs, r, mu in self.example_parameters_genotypes(ts): diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 92f1ab3..815a278 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -1,27 +1,17 @@ -# Simulation import itertools -# Python libraries import msprime import numpy as np import pytest import tskit import lshmm as ls +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - class LSBase: """Superclass of Li and Stephens tests.""" @@ -33,13 +23,13 @@ def example_haplotypes(self, ts, num_random=10, seed=42): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -77,7 +67,7 @@ def _get_num_alleles(ref_haps, query): assert ref_haps.shape[0] == query.shape[1] num_sites = ref_haps.shape[0] num_alleles = np.zeros(num_sites, dtype=np.int8) - exclusion_set = np.array([MISSING]) + exclusion_set = np.array([core.MISSING]) for i in range(num_sites): uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) @@ -115,7 +105,6 @@ def _get_num_alleles(ref_haps, query): yield n, m, H, s, e, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" assert np.allclose(A, B, rtol=1e-9, atol=0.0) # Define a bunch of very small tree-sequences for testing a collection of parameters on @@ -187,7 +176,6 @@ def verify(self, ts): B = ls.backwards(H_vs, s, c, r, p_mutation=mu) self.assertAllClose(F, F_vs) self.assertAllClose(B, B_vs) - # print(e_vs) self.assertAllClose(ll_vs, ll) for n, m, H_vs, s, e_vs, r, mu in self.example_parameters_haplotypes( diff --git a/tests/test_LS_haploid_diploid.py b/tests/test_LS_haploid_diploid.py index 9b9f7d8..ae4ee9a 100644 --- a/tests/test_LS_haploid_diploid.py +++ b/tests/test_LS_haploid_diploid.py @@ -1,28 +1,18 @@ -# Simulation import itertools +import pytest -# Python libraries -import msprime import numpy as np -import pytest +import numba as nb + +import msprime +import tskit +import lshmm.core as core import lshmm.forward_backward.fb_diploid as fbd import lshmm.forward_backward.fb_haploid as fbh import lshmm.vit_diploid as vd import lshmm.vit_haploid as vh -EQUAL_BOTH_HOM = 4 -UNEQUAL_BOTH_HOM = 0 -BOTH_HET = 7 -REF_HOM_OBS_HET = 1 -REF_HET_OBS_HOM = 2 - -MISSING = -1 -MISSING_INDEX = 3 - -import numba as nb -import tskit - class LSBase: """Superclass of Li and Stephens tests.""" @@ -34,13 +24,13 @@ def example_haplotypes(self, ts): haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING haplotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING haplotypes.append(s_tmp) return H, haplotypes @@ -50,19 +40,17 @@ def haplotype_emission(self, mu, m): e = np.zeros((m, 2)) e[:, 0] = mu # If they match e[:, 1] = 1 - mu # If they don't match - return e def genotype_emission(self, mu, m): # Define the emission probability matrix e = np.zeros((m, 8)) - e[:, EQUAL_BOTH_HOM] = (1 - mu) ** 2 - e[:, UNEQUAL_BOTH_HOM] = mu**2 - e[:, BOTH_HET] = 1 - mu - e[:, REF_HOM_OBS_HET] = 2 * mu * (1 - mu) - e[:, REF_HET_OBS_HOM] = mu * (1 - mu) - e[:, MISSING_INDEX] = 1 - + e[:, core.EQUAL_BOTH_HOM] = (1 - mu) ** 2 + e[:, core.UNEQUAL_BOTH_HOM] = mu**2 + e[:, core.BOTH_HET] = 1 - mu + e[:, core.REF_HOM_OBS_HET] = 2 * mu * (1 - mu) + e[:, core.REF_HET_OBS_HOM] = mu * (1 - mu) + e[:, core.MISSING_INDEX] = 1 return e def example_parameters_haplotypes(self, ts, seed=42): @@ -124,13 +112,13 @@ def example_genotypes(self, ts, seed=42): ] s_tmp = s.copy() - s_tmp[0, -1] = MISSING + s_tmp[0, -1] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, ts.num_sites // 2] = MISSING + s_tmp[0, ts.num_sites // 2] = core.MISSING genotypes.append(s_tmp) s_tmp = s.copy() - s_tmp[0, :] = MISSING + s_tmp[0, :] = core.MISSING genotypes.append(s_tmp) m = ts.get_num_sites()