From 908715c824d54f0e6524f9932703854d89b0a828 Mon Sep 17 00:00:00 2001 From: Shing Zhan Date: Mon, 4 Mar 2024 13:44:35 +0000 Subject: [PATCH] Implement mutation rate scaling based on the number of alleles --- python/tests/beagle_numba.py | 39 +++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/python/tests/beagle_numba.py b/python/tests/beagle_numba.py index ccfa590808..7360fc7b4f 100644 --- a/python/tests/beagle_numba.py +++ b/python/tests/beagle_numba.py @@ -324,8 +324,14 @@ def get_transition_probs(cm, h, ne): def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2): """ + Compute emission probability at a site based on whether the alleles + carried by a query haplotype and a reference haplotype match at the site. + + The emission probability may be scaled according to the number of distinct + segregating alleles. By default, it is assumeed the site is biallelic. + :param float mismatch_prob: Mismatch probability. - :param bool is_match: True if match, otherwise False. + :param bool is_match: True if matched, otherwise False. :param int num_alleles: Number of distinct alleles (default = 2). :return: Emission probability :rtype: float @@ -337,7 +343,9 @@ def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2): @njit -def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): +def compute_forward_matrix( + ref_h, query_h, trans_probs, mismatch_probs, *, num_alleles=2 +): """ Implement Li and Stephens forward algorithm. @@ -350,6 +358,7 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): :param numpy.ndarray query_h: One query haplotype. :param numpy.ndarray trans_probs: Transition probabilities. :param numpy.ndarray mismatch_probs: Mismatch probabilities. + :param int num_alleles: Number of distinct alleles (default = 2). :return: Forward probability matrix. :rtype: numpy.ndarray """ @@ -367,7 +376,9 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): # Get allele at genotyped position i on reference haplotype j. ref_a = ref_h[i, j] # Get emission probability. - em_prob = compute_emission_probability(mismatch_probs[i], query_a == ref_a) + em_prob = compute_emission_probability( + mismatch_probs[i], query_a == ref_a, num_alleles=num_alleles + ) fwd_mat[i, j] = em_prob if i > 0: fwd_mat[i, j] *= scale * fwd_mat[i - 1, j] + shift @@ -378,7 +389,9 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): @njit -def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs): +def compute_backward_matrix( + ref_h, query_h, trans_probs, mismatch_probs, *, num_alleles=2 +): """ Implement Li and Stephens backward algorithm. @@ -394,6 +407,7 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs): :param numpy.ndarray query_h: One query haplotype. :param numpy.ndarray trans_probs: Transition probabilities. :param numpy.ndarray mismatch_probs: Mismatch probabilities. + :param int num_alleles: Number of distinct alleles (default = 2). :return: Backward probability matrix. :rtype: numpy.ndarray """ @@ -407,7 +421,9 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs): for j in range(h): ref_a = ref_h[iP1, j] em_prob = compute_emission_probability( - mismatch_probs[iP1], query_a == ref_a + mismatch_probs[iP1], + query_a == ref_a, + num_alleles=num_alleles, ) bwd_mat[iP1, j] *= em_prob site_sum = np.sum(bwd_mat[iP1, :]) @@ -635,6 +651,7 @@ def run_interpolation_beagle( "Check the reference and query haplotypes use the same allele encoding.", stacklevel=1, ) + num_alleles = len(tskit.ALLELES_ACGT) h = ref_h.shape[1] # Number of reference haplotypes. # Separate indices of genotyped and ungenotyped positions. idx_typed = np.where(query_h != tskit.MISSING_DATA)[0] @@ -654,10 +671,18 @@ def run_interpolation_beagle( mismatch_probs = get_mismatch_probs(pos_typed, error_rate=error_rate) # Compute matrices at genotyped positions. fwd_mat = compute_forward_matrix( - ref_h_typed, query_h_typed, trans_probs, mismatch_probs + ref_h_typed, + query_h_typed, + trans_probs, + mismatch_probs, + num_alleles=num_alleles, ) bwd_mat = compute_backward_matrix( - ref_h_typed, query_h_typed, trans_probs, mismatch_probs + ref_h_typed, + query_h_typed, + trans_probs, + mismatch_probs, + num_alleles=num_alleles, ) state_mat = compute_state_prob_matrix(fwd_mat, bwd_mat) # Interpolate allele probabilities.