From e0a0674c7a1be187af65fd44cd055261c0fa2733 Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 17 Jun 2024 15:54:20 +0100 Subject: [PATCH] Reorganise functions in core --- lshmm/core.py | 162 +++++++++++++++++++++++++++++--------------------- 1 file changed, 93 insertions(+), 69 deletions(-) diff --git a/lshmm/core.py b/lshmm/core.py index c8c3e75..60d17b6 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -14,9 +14,7 @@ NONCOPY = -2 -""" Helper functions. """ - - +# Helper functions. # https://github.com/numba/numba/issues/1269 @jit.numba_njit def np_apply_along_axis(func1d, axis, arr): @@ -52,9 +50,7 @@ def np_argmax(array, axis): return np_apply_along_axis(np.argmax, axis, array) -""" Functions used across different implementations of the LS HMM. """ - - +# Functions used across different implementations of LS HMM. def convert_haplotypes_to_phased_genotypes(ref_panel): """ Convert a set of haplotypes into a matrix of diploid genotypes encoded as allele dosages, @@ -68,19 +64,21 @@ def convert_haplotypes_to_phased_genotypes(ref_panel): at a site, then the genotype at the site is assigned NONCOPY. The input reference haplotypes is of size (m, n), and the output genotypes is of size (m, n, n), - where m = number of sites and n = number of reference haplotypes. + where: + m = number of sites. + n = number of reference haplotypes. :param numpy.ndarray ref_panel: An array of reference haplotypes. :return: An array of reference genotypes. :rtype: numpy.ndarray """ - ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int32) + ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int8) assert np.all( np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES) ), f"Reference haplotypes contain illegal allele states." 
num_sites = ref_panel.shape[0] num_haps = ref_panel.shape[1] - genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int32) - np.inf + genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int8) - np.inf for i in range(num_sites): site_alleles = ref_panel[i, :] genotypes[i, :, :] = np.add.outer(site_alleles, site_alleles) @@ -108,14 +106,14 @@ def convert_haplotypes_to_unphased_genotypes(query): :return: An array of query genotypes. :rtype: numpy.ndarray """ - ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int32) + ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int8) assert np.all( np.isin(np.unique(query), ALLOWED_ALLELE_STATES) ), f"Query haplotypes contain illegal allele states." num_sites = query.shape[1] num_haps = query.shape[0] assert num_haps == 2, "Two haplotypes are expected in a diploid query." - genotypes = np.zeros((1, num_sites), dtype=np.int32) - np.inf + genotypes = np.zeros((1, num_sites), dtype=np.int8) - np.inf genotypes[0, :] = np.sum(query, axis=0) genotypes[0, np.any(query == MISSING, axis=0)] = MISSING return genotypes @@ -124,14 +122,14 @@ def convert_haplotypes_to_unphased_genotypes(query): def check_genotype_matrix(genotype_matrix, num_sample_haps): """ Check that at each site the number of non-NONCOPY values in the reference panel - in the form of a genotype matrix is at most a maximum. + in the form of a genotype matrix is at most a maximum value. The genotype matrix is an array of size (m, n), where: m = number of sites. n = number of haplotypes (sample and ancestor) in the reference panel. - The maximum is equal to (2n - 1), where n is the number of sample haplotypes + The maximum value is equal to (2n - 1), where n is the number of sample haplotypes in the genotype matrix, when a marginal tree is fully binary. :param numpy.ndarray genotype_matrix: An array containing the reference haplotypes. 
@@ -201,51 +199,14 @@ def check_alleles(alleles, num_sites): return np.int8([len(alleles) for _ in range(num_sites)]) # Otherwise, process allele lists. exclusion_set = np.array([MISSING, NONCOPY]) - num_alleles = np.zeros(num_sites, dtype=np.int32) + num_alleles = np.zeros(num_sites, dtype=np.int8) for i in range(num_sites): uniq_alleles = np.unique(alleles[i]) num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) return num_alleles -@jit.numba_njit -def get_index_in_emission_matrix_haploid(ref_allele, query_allele): - is_allele_match = ref_allele == query_allele - is_query_missing = query_allele == MISSING - if is_allele_match or is_query_missing: - return 1 - return 0 - - -@jit.numba_njit -def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype): - """ - Compare the implied unphased genotypes (allele dosages) of - the reference and query to get the index of the entry - in the emission probability matrix, and return the index. - """ - if query_genotype == MISSING: - return MISSING_INDEX - else: - is_match = ref_genotype == query_genotype - is_ref_het = ref_genotype == 1 - is_query_het = query_genotype == 1 - return 4 * is_match + 2 * is_ref_het + is_query_het - - -@jit.numba_njit -def get_index_in_emission_matrix_diploid_genotypes( - ref_genotypes, query_genotype, num_ref_haps -): - if query_genotype == MISSING: - return MISSING_INDEX * np.ones((num_ref_haps, num_ref_haps), dtype=np.int64) - else: - is_match = ref_genotypes == query_genotype - is_ref_het = ref_genotypes == 1 - is_query_het = query_genotype == 1 - return 4 * is_match + 2 * is_ref_het + is_query_het - - +# Functions to assign emission probabilities for haploid LS HMM. @jit.numba_njit def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate): """ @@ -266,10 +227,10 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate) probability is the probability of mutation **any given one** of the alleles. 
The overall mutation probability is then (num_alleles - 1) * mutation probability. - :param float/numpy.ndarray(dtype=np.float64) mu: Probability of mutation. + :param float/numpy.ndarray mu: Mutation probability. :param int num_sites: Number of sites. - :param numpy.ndarray(dtype=np.int8): Number of distinct alleles per site. - :param bool scale_mutation_rate: Scale mutation rate based on the number of alleles if True (default). + :param numpy.ndarray num_alleles: Number of distinct alleles per site. + :param bool scale_mutation_rate: Scale mutation rate based on the number of alleles. """ assert len(mu) == len( num_alleles @@ -294,8 +255,57 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate) return emission_matrix +@jit.numba_njit +def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix): + if ref_allele == NONCOPY: + return 0.0 + else: + emission_index = get_index_in_emission_matrix_haploid(ref_allele, query_allele) + return emission_matrix[site, emission_index] + + +@jit.numba_njit +def get_index_in_emission_matrix_haploid(ref_allele, query_allele): + is_allele_match = ref_allele == query_allele + is_query_missing = query_allele == MISSING + if is_allele_match or is_query_missing: + return 1 + return 0 + + +# Functions to assign emission probabilities for diploid LS HMM. @jit.numba_njit def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate): + """ + Compute an emission probability matrix for the diploid case, and return it. + + The emission probability matrix is of size (num_sites, 8). The entries are indexed + as follows: + EQUAL_BOTH_HOM = 4 + UNEQUAL_BOTH_HOM = 0 + BOTH_HET = 7 + REF_HOM_OBS_HET = 1 + REF_HET_OBS_HOM = 2 + MISSING_INDEX = 3 + + Note that indices 5 and 6 are unused and set to negative infinity. 
+ + By default, there is no scaling of mutation rates based on the number of alleles, + so that mutation probability is the probability of mutation to **any allele** + (therefore, summing over all the states that can be switched to). + + This means that we must rescale the probability of mutation to a different allele + by the number of alleles at the site. + + Optionally, scale mutation based on the number of alleles, so that mutation + probability is the probability of mutation to **any given one** of the alleles. + The overall mutation probability is then (num_alleles - 1) * mutation probability. + + :param float/numpy.ndarray mu: Mutation probability. + :param int num_sites: Number of sites. + :param numpy.ndarray num_alleles: Number of distinct alleles per site. + :param bool scale_mutation_rate: Scale mutation rate based on the number of alleles. + """ assert len(mu) == len( num_alleles ), "Arrays of mutation probability and number of alleles are unequal in length." @@ -326,15 +336,6 @@ def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate) return emission_matrix -@jit.numba_njit -def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix): - if ref_allele == NONCOPY: - return 0.0 - else: - emission_index = get_index_in_emission_matrix_haploid(ref_allele, query_allele) - return emission_matrix[site, emission_index] - - @jit.numba_njit def get_emission_probability_diploid( ref_genotype, query_genotype, site, emission_matrix @@ -352,11 +353,13 @@ def get_emission_probability_diploid( def get_emission_probability_diploid_genotypes( ref_genotypes, query_genotype, site, emission_matrix ): - assert ref_genotypes.shape[0] == ref_genotypes.shape[1] - num_ref_haps = len(ref_genotypes) - emission_probs = np.zeros((num_ref_haps, num_ref_haps), dtype=np.float64) - for i in range(num_ref_haps): - for j in range(num_ref_haps): + assert ( + ref_genotypes.shape[0] == ref_genotypes.shape[1] + ), "Reference genotype matrix must be a square 
matrix." + num_ref_genotypes = len(ref_genotypes) + emission_probs = np.zeros((num_ref_genotypes, num_ref_genotypes), dtype=np.float64) + for i in range(num_ref_genotypes): + for j in range(num_ref_genotypes): if ref_genotypes[i, j] == NONCOPY: emission_probs[i, j] = 0.0 else: @@ -365,3 +368,24 @@ def get_emission_probability_diploid_genotypes( ) emission_probs[i, j] = emission_matrix[site, emission_index] return emission_probs + + +@jit.numba_njit +def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype): + """ + Compare the implied unphased genotypes (allele dosages) of + the reference and query to get the index of the entry + in the emission probability matrix, and return the index. + + :param int ref_genotype: Reference genotype (allele dosage). + :param int query_genotype: Query genotype (allele dosage). + :return: Index in emission probability matrix. + :rtype: int + """ + if query_genotype == MISSING: + return MISSING_INDEX + else: + is_match = ref_genotype == query_genotype + is_ref_het = ref_genotype == 1 + is_query_het = query_genotype == 1 + return 4 * is_match + 2 * is_ref_het + is_query_het