Skip to content

Commit

Permalink
Merge pull request #99 from szhan/refactor_get_emission_prob_haploid
Browse files Browse the repository at this point in the history
Refactor function to get emission probability in the haploid case
  • Loading branch information
szhan authored Jun 18, 2024
2 parents 4cc870d + cd3683c commit ed35954
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
37 changes: 26 additions & 11 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,35 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate)

@jit.numba_njit
def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix):
    """
    Look up the haploid emission probability for one site.

    The emission probability matrix must have exactly two columns (this is
    asserted below); rows are indexed by site, so it has size (m, 2) where
    m = number of sites. Because only match/mismatch is distinguished, this
    handles multiallelic sites.

    Special cases, checked before the matrix lookup:

    - a NONCOPY reference allele yields a probability of 0.0;
    - a MISSING query allele yields a probability of 1.0.

    :param int ref_allele: Reference allele (must not be MISSING).
    :param int query_allele: Query allele (must not be NONCOPY).
    :param int site: Site index (row of the emission matrix).
    :param numpy.ndarray emission_matrix: Emission probability matrix.
    :return: Emission probability.
    :rtype: float
    """
    # Validate the inputs up front.
    assert ref_allele != MISSING, "Reference allele cannot be MISSING."
    assert query_allele != NONCOPY, "Query allele cannot be NONCOPY."
    assert (
        emission_matrix.shape[1] == 2
    ), "Emission probability matrix has incorrect shape."

    # Special alleles short-circuit the matrix lookup.
    if ref_allele == NONCOPY:
        return 0.0
    if query_allele == MISSING:
        return 1.0

    # Ordinary case: pick the mismatch/match column for this site.
    column = get_index_in_emission_matrix_haploid(ref_allele, query_allele)
    return emission_matrix[site, column]


@jit.numba_njit
def get_index_in_emission_matrix_haploid(ref_allele, query_allele):
    """
    Return the column index into the haploid emission probability matrix.

    Index 1 is returned when the reference and query alleles are equal or
    when the query allele is MISSING; index 0 is returned otherwise.

    :param int ref_allele: Reference allele.
    :param int query_allele: Query allele.
    :return: Column index (0 or 1).
    :rtype: int
    """
    if ref_allele == query_allele or query_allele == MISSING:
        return 1
    return 0
if ref_allele != query_allele:
return emission_matrix[site, 0]
else:
return emission_matrix[site, 1]


# Functions to assign emission probabilities for diploid LS HMM.
Expand Down
2 changes: 1 addition & 1 deletion tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_examples_haploid(self, ts, include_ancestors):
query_miss_mid = query1.copy()
query_miss_mid[0, ts.num_sites // 2] = core.MISSING
query_miss_most = query1.copy()
query_miss_most[0, 1:] = core.MISSING
query_miss_most[0, 2:] = core.MISSING
queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most]
# Exclude the arbitrarily chosen queries from the reference panel.
ref_panel = ref_panel[:, 2:]
Expand Down

0 comments on commit ed35954

Please sign in to comment.