From cd3683cf96c8cf8ebb74b600c961dc135c914fcf Mon Sep 17 00:00:00 2001 From: szhan Date: Mon, 17 Jun 2024 22:06:00 +0100 Subject: [PATCH] Refactor function to get emission probability in the haploid case --- lshmm/core.py | 37 ++++++++++++++++++++++++++----------- tests/lsbase.py | 2 +- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/lshmm/core.py b/lshmm/core.py index 60d17b6..d4ca0a7 100644 --- a/lshmm/core.py +++ b/lshmm/core.py @@ -257,20 +257,35 @@ def get_emission_matrix_haploid(mu, num_sites, num_alleles, scale_mutation_rate) @jit.numba_njit def get_emission_probability_haploid(ref_allele, query_allele, site, emission_matrix): + """ + Return the emission probability at a specified site for the haploid case, + given an emission probability matrix. + + The emission probability matrix is an array of size (m, 2), where m = number of sites. + + This handle multiallelic sites. + + :param int ref_allele: Reference allele. + :param int query_allele: Query allele. + :param int site: Site index. + :param numpy.ndarray emission_matrix: Emission probability matrix. + :return: Emission probability. + :rtype: float + """ + assert ref_allele != MISSING, "Reference allele cannot be MISSING." + assert query_allele != NONCOPY, "Query allele cannot be NONCOPY." + assert ( + emission_matrix.shape[1] == 2 + ), "Emission probability matrix has incorrect shape." if ref_allele == NONCOPY: return 0.0 + elif query_allele == MISSING: + return 1.0 else: - emission_index = get_index_in_emission_matrix_haploid(ref_allele, query_allele) - return emission_matrix[site, emission_index] - - -@jit.numba_njit -def get_index_in_emission_matrix_haploid(ref_allele, query_allele): - is_allele_match = ref_allele == query_allele - is_query_missing = query_allele == MISSING - if is_allele_match or is_query_missing: - return 1 - return 0 + if ref_allele != query_allele: + return emission_matrix[site, 0] + else: + return emission_matrix[site, 1] # Functions to assign emission probabilities for diploid LS HMM. diff --git a/tests/lsbase.py b/tests/lsbase.py index 4d16d7d..4cfacd7 100644 --- a/tests/lsbase.py +++ b/tests/lsbase.py @@ -77,7 +77,7 @@ def get_examples_haploid(self, ts, include_ancestors): query_miss_mid = query1.copy() query_miss_mid[0, ts.num_sites // 2] = core.MISSING query_miss_most = query1.copy() - query_miss_most[0, 1:] = core.MISSING + query_miss_most[0, 2:] = core.MISSING queries = [query1, query2, query_miss_last, query_miss_mid, query_miss_most] # Exclude the arbitrarily chosen queries from the reference panel. ref_panel = ref_panel[:, 2:]