diff --git a/python/tests/beagle_numba.py b/python/tests/beagle_numba.py index e3421cea28..597f2b8aac 100644 --- a/python/tests/beagle_numba.py +++ b/python/tests/beagle_numba.py @@ -322,6 +322,20 @@ def get_transition_probs(cm, h, ne): return trans_probs +def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2): + """ + :param float mismatch_prob: Mismatch probability. + :param bool is_match: True if match, otherwise mismatch. + :param int num_alleles: Number of distinct alleles (default = 2). + :return: Emission probability + :rtype: float + """ + em_prob = mismatch_prob + if is_match: + em_prob = 1.0 - (num_alleles - 1) * mismatch_prob + return em_prob + + @njit def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): """ @@ -353,9 +367,7 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs): # Get allele at genotyped position i on reference haplotype j. ref_a = ref_h[i, j] # Get emission probability. - em_prob = mismatch_probs[i] - if query_a == ref_a: - em_prob = 1.0 - mismatch_probs[i] + em_prob = compute_emission_probability(mismatch_probs[i], query_a == ref_a) fwd_mat[i, j] = em_prob if i > 0: fwd_mat[i, j] *= scale * fwd_mat[i - 1, j] + shift @@ -390,17 +402,18 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs): bwd_mat = np.zeros((m, h), dtype=np.float64) bwd_mat[-1, :] = 1.0 / h # Initialise the last column. for i in range(m - 2, -1, -1): - query_a = query_h[i + 1] + iP1 = i + 1 + query_a = query_h[iP1] for j in range(h): - ref_a = ref_h[i + 1, j] - em_prob = mismatch_probs[i + 1] - if ref_a == query_a: - em_prob = 1.0 - mismatch_probs[i + 1] - bwd_mat[i + 1, j] *= em_prob - site_sum = np.sum(bwd_mat[i + 1, :]) - scale = (1 - trans_probs[i + 1]) / site_sum - shift = trans_probs[i + 1] / h - bwd_mat[i, :] = scale * bwd_mat[i + 1, :] + shift + ref_a = ref_h[iP1, j] + em_prob = compute_emission_probability( + mismatch_probs[iP1], query_a == ref_a + ) + bwd_mat[iP1, j] *= em_prob + site_sum = np.sum(bwd_mat[iP1, :]) + scale = (1 - trans_probs[iP1]) / site_sum + shift = trans_probs[iP1] / h + bwd_mat[i, :] = scale * bwd_mat[iP1, :] + shift return bwd_mat