Skip to content

Commit

Permalink
Merge pull request #97 from szhan/reorganise_core
Browse files Browse the repository at this point in the history
Reorganise functions in core
  • Loading branch information
szhan authored Jun 17, 2024
2 parents d1d8829 + e0a0674 commit 945c0a1
Showing 1 changed file with 93 additions and 69 deletions.
162 changes: 93 additions & 69 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
NONCOPY = -2


""" Helper functions. """


# Helper functions.
# https://github.com/numba/numba/issues/1269
@jit.numba_njit
def np_apply_along_axis(func1d, axis, arr):
Expand Down Expand Up @@ -52,9 +50,7 @@ def np_argmax(array, axis):
return np_apply_along_axis(np.argmax, axis, array)


""" Functions used across different implementations of the LS HMM. """


# Functions used across different implementations of the LS HMM.
def convert_haplotypes_to_phased_genotypes(ref_panel):
"""
Convert a set of haplotypes into a matrix of diploid genotypes encoded as allele dosages,
Expand All @@ -68,19 +64,21 @@ def convert_haplotypes_to_phased_genotypes(ref_panel):
at a site, then the genotype at the site is assigned NONCOPY.
The input reference haplotypes is of size (m, n), and the output genotypes is of size (m, n, n),
where m = number of sites and n = number of reference haplotypes.
where:
m = number of sites.
n = number of reference haplotypes.
:param numpy.ndarray ref_panel: An array of reference haplotypes.
:return: An array of reference genotypes.
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int32)
ALLOWED_ALLELE_STATES = np.array([0, 1, NONCOPY], dtype=np.int8)
assert np.all(
np.isin(np.unique(ref_panel), ALLOWED_ALLELE_STATES)
), f"Reference haplotypes contain illegal allele states."
num_sites = ref_panel.shape[0]
num_haps = ref_panel.shape[1]
genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int32) - np.inf
genotypes = np.zeros((num_sites, num_haps, num_haps), dtype=np.int8) - np.inf
for i in range(num_sites):
site_alleles = ref_panel[i, :]
genotypes[i, :, :] = np.add.outer(site_alleles, site_alleles)
Expand Down Expand Up @@ -108,14 +106,14 @@ def convert_haplotypes_to_unphased_genotypes(query):
:return: An array of query genotypes.
:rtype: numpy.ndarray
"""
ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int32)
ALLOWED_ALLELE_STATES = np.array([0, 1, MISSING], dtype=np.int8)
assert np.all(
np.isin(np.unique(query), ALLOWED_ALLELE_STATES)
), f"Query haplotypes contain illegal allele states."
num_sites = query.shape[1]
num_haps = query.shape[0]
assert num_haps == 2, "Two haplotypes are expected in a diploid query."
genotypes = np.zeros((1, num_sites), dtype=np.int32) - np.inf
genotypes = np.zeros((1, num_sites), dtype=np.int8) - np.inf
genotypes[0, :] = np.sum(query, axis=0)
genotypes[0, np.any(query == MISSING, axis=0)] = MISSING
return genotypes
Expand All @@ -124,14 +122,14 @@ def convert_haplotypes_to_unphased_genotypes(query):
def check_genotype_matrix(genotype_matrix, num_sample_haps):
"""
Check that at each site the number of non-NONCOPY values in the reference panel
in the form of a genotype matrix is at most a maximum.
in the form of a genotype matrix is at most a maximum value.
The genotype matrix is an array of size (m, n),
where:
m = number of sites.
n = number of haplotypes (sample and ancestor) in the reference panel.
The maximum is equal to (2n - 1), where n is the number of sample haplotypes
The maximum value is equal to (2n - 1), where n is the number of sample haplotypes
in the genotype matrix, when a marginal tree is fully binary.
:param numpy.ndarray genotype_matrix: An array containing the reference haplotypes.
Expand Down Expand Up @@ -201,51 +199,14 @@ def check_alleles(alleles, num_sites):
return np.int8([len(alleles) for _ in range(num_sites)])
# Otherwise, process allele lists.
exclusion_set = np.array([MISSING, NONCOPY])
num_alleles = np.zeros(num_sites, dtype=np.int32)
num_alleles = np.zeros(num_sites, dtype=np.int8)
for i in range(num_sites):
uniq_alleles = np.unique(alleles[i])
num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set))
return num_alleles


@jit.numba_njit
def get_index_in_emission_matrix_haploid(ref_allele, query_allele):
    """
    Compare a reference allele and a query allele to get the index of the entry
    in the haploid emission probability matrix, and return the index.

    :param ref_allele: Reference allele.
    :param query_allele: Query allele.
    :return: 1 if the alleles are equal or the query allele is MISSING, otherwise 0.
    :rtype: int
    """
    alleles_agree = ref_allele == query_allele
    query_is_missing = query_allele == MISSING
    return 1 if (alleles_agree or query_is_missing) else 0


@jit.numba_njit
def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype):
    """
    Compare the implied unphased genotypes (allele dosages) of
    the reference and query to get the index of the entry
    in the emission probability matrix, and return the index.

    A MISSING query genotype maps to MISSING_INDEX. Otherwise, the index is
    the sum of 4 if the genotypes are equal, 2 if the reference genotype is 1,
    and 1 if the query genotype is 1.
    """
    # A missing query is handled first; it has its own dedicated index.
    if query_genotype == MISSING:
        return MISSING_INDEX
    genotypes_equal = ref_genotype == query_genotype
    ref_dosage_is_one = ref_genotype == 1
    query_dosage_is_one = query_genotype == 1
    return 4 * genotypes_equal + 2 * ref_dosage_is_one + query_dosage_is_one


@jit.numba_njit
def get_index_in_emission_matrix_diploid_genotypes(
    ref_genotypes, query_genotype, num_ref_haps
):
    """
    Compare an array of reference genotypes to a single query genotype to get
    the indices of the entries in the emission probability matrix, and return them.

    If the query genotype is MISSING, every entry of the returned
    (num_ref_haps, num_ref_haps) array is MISSING_INDEX. Otherwise, each entry
    is the sum of 4 if the reference genotype equals the query genotype,
    2 if the reference genotype is 1, and 1 if the query genotype is 1.

    :param numpy.ndarray ref_genotypes: Reference genotypes
        (presumably of size (num_ref_haps, num_ref_haps); confirm against callers).
    :param int query_genotype: Query genotype.
    :param int num_ref_haps: Number of reference haplotypes.
    :return: Indices in emission probability matrix.
    :rtype: numpy.ndarray
    """
    if query_genotype == MISSING:
        all_ones = np.ones((num_ref_haps, num_ref_haps), dtype=np.int64)
        return MISSING_INDEX * all_ones
    match_mask = ref_genotypes == query_genotype
    ref_het_mask = ref_genotypes == 1
    query_is_het = query_genotype == 1
    return 4 * match_mask + 2 * ref_het_mask + query_is_het


# Functions to assign emission probabilities for haploid LS HMM.
@jit.numba_njit
def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate):
"""
Expand All @@ -266,10 +227,10 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)
probability is the probability of mutation to **any given one** of the alleles.
The overall mutation probability is then (num_alleles - 1) * mutation probability.
:param float/numpy.ndarray(dtype=np.float64) mu: Probability of mutation.
:param float/numpy.ndarray mu: Mutation probability.
:param int num_sites: Number of sites.
:param numpy.ndarray(dtype=np.int8): Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles if True (default).
:param numpy.ndarray num_alleles: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
Expand All @@ -294,8 +255,57 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)
return emission_matrix


@jit.numba_njit
def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix):
    """
    Look up the emission probability at a site for a pair of reference and
    query alleles, and return it.

    :param ref_allele: Reference allele.
    :param query_allele: Query allele.
    :param int site: Index of the site.
    :param numpy.ndarray emission_matrix: Emission probability matrix.
    :return: 0.0 if the reference allele is NONCOPY, otherwise the matrix entry
        at the site and the index for the allele pair.
    """
    # A NONCOPY reference allele is assigned zero emission probability.
    if ref_allele == NONCOPY:
        return 0.0
    column = get_index_in_emission_matrix_haploid(ref_allele, query_allele)
    return emission_matrix[site, column]


@jit.numba_njit
def get_index_in_emission_matrix_haploid(ref_allele, query_allele):
    """
    Compare a reference allele and a query allele to get the index of the entry
    in the haploid emission probability matrix, and return the index.

    :param ref_allele: Reference allele.
    :param query_allele: Query allele.
    :return: 1 if the alleles are equal or the query allele is MISSING, otherwise 0.
    :rtype: int
    """
    # A missing query allele gets the same index as a matching allele.
    if query_allele == MISSING:
        return 1
    if ref_allele == query_allele:
        return 1
    return 0


# Functions to assign emission probabilities for diploid LS HMM.
@jit.numba_njit
def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate):
"""
Compute an emission probability matrix for the diploid case, and return it.
The emission probability matrix is of size (num_sites, 8). The entries are indexed
as follows:
EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
REF_HOM_OBS_HET = 1
REF_HET_OBS_HOM = 2
MISSING_INDEX = 3
Note that indices 5 and 6 are unused and set to negative infinity.
By default, there is no scaling of mutation rates based on the number of alleles,
so that mutation probability is the probability of mutation to **any allele**
(therefore, summing over all the states that can be switched to).
This means that we must rescale the probability of mutation to a different allele
by the number of alleles at the site.
Optionally, scale mutation based on the number of alleles, so that mutation
probability is the probability of mutation to **any given one** of the alleles.
The overall mutation probability is then (num_alleles - 1) * mutation probability.
:param float/numpy.ndarray mu: Mutation probability.
:param int num_sites: Number of sites.
:param numpy.ndarray num_alleles: Number of distinct alleles per site.
:param bool scale_mutation_rate: Scale mutation rate based on the number of alleles.
"""
assert len(mu) == len(
num_alleles
), "Arrays of mutation probability and number of alleles are unequal in length."
Expand Down Expand Up @@ -326,15 +336,6 @@ def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate)
return emission_matrix


@jit.numba_njit
def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix):
    """
    Look up the emission probability at a site for a pair of reference and
    query alleles, and return it.

    :param ref_allele: Reference allele.
    :param query_allele: Query allele.
    :param int site: Index of the site.
    :param numpy.ndarray emission_matrix: Emission probability matrix.
    :return: 0.0 if the reference allele is NONCOPY, otherwise the matrix entry
        at the site and the index for the allele pair.
    """
    if ref_allele != NONCOPY:
        idx = get_index_in_emission_matrix_haploid(ref_allele, query_allele)
        return emission_matrix[site, idx]
    # A NONCOPY reference allele is assigned zero emission probability.
    return 0.0


@jit.numba_njit
def get_emission_probability_diploid(
ref_genotype, query_genotype, site, emission_matrix
Expand All @@ -352,11 +353,13 @@ def get_emission_probability_diploid(
def get_emission_probability_diploid_genotypes(
ref_genotypes, query_genotype, site, emission_matrix
):
assert ref_genotypes.shape[0] == ref_genotypes.shape[1]
num_ref_haps = len(ref_genotypes)
emission_probs = np.zeros((num_ref_haps, num_ref_haps), dtype=np.float64)
for i in range(num_ref_haps):
for j in range(num_ref_haps):
assert (
ref_genotypes.shape[0] == ref_genotypes.shape[1]
), "Reference genotype matrix must be a square matrix."
num_ref_genotypes = len(ref_genotypes)
emission_probs = np.zeros((num_ref_genotypes, num_ref_genotypes), dtype=np.float64)
for i in range(num_ref_genotypes):
for j in range(num_ref_genotypes):
if ref_genotypes[i, j] == NONCOPY:
emission_probs[i, j] = 0.0
else:
Expand All @@ -365,3 +368,24 @@ def get_emission_probability_diploid_genotypes(
)
emission_probs[i, j] = emission_matrix[site, emission_index]
return emission_probs


@jit.numba_njit
def get_index_in_emission_matrix_diploid(ref_genotype, query_genotype):
    """
    Get the index of the entry in the diploid emission probability matrix
    for a pair of reference and query unphased genotypes (allele dosages),
    and return it.

    A MISSING query genotype maps to MISSING_INDEX. Otherwise, the index is
    the sum of 4 if the genotypes are equal, 2 if the reference genotype is 1,
    and 1 if the query genotype is 1.

    :param int ref_genotype: Reference genotype (allele dosage).
    :param int query_genotype: Query genotype (allele dosage).
    :return: Index in emission probability matrix.
    :rtype: int
    """
    # A missing query is handled first; it has its own dedicated index.
    if query_genotype == MISSING:
        return MISSING_INDEX
    same_genotype = ref_genotype == query_genotype
    ref_is_one = ref_genotype == 1
    query_is_one = query_genotype == 1
    return 4 * same_genotype + 2 * ref_is_one + query_is_one

0 comments on commit 945c0a1

Please sign in to comment.