Skip to content

Commit

Permalink
Implement compute_emission_probability
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 4, 2024
1 parent 0581375 commit 4393ad7
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,20 @@ def get_transition_probs(cm, h, ne):
return trans_probs


def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2):
"""
:param float mismatch_prob: Mismatch probability.
:param bool is_match: True if match, otherwise mismatch.
:param int num_alleles: Number of distinct alleles (default = 2).
:return: Emission probability
:rtype: float
"""
em_prob = mismatch_prob
if is_match:
em_prob = 1.0 - (num_alleles - 1) * mismatch_prob
return em_prob


@njit
def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
"""
Expand Down Expand Up @@ -353,9 +367,7 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
# Get allele at genotyped position i on reference haplotype j.
ref_a = ref_h[i, j]
# Get emission probability.
em_prob = mismatch_probs[i]
if query_a == ref_a:
em_prob = 1.0 - mismatch_probs[i]
em_prob = compute_emission_probability(mismatch_probs[i], query_a == ref_a)
fwd_mat[i, j] = em_prob
if i > 0:
fwd_mat[i, j] *= scale * fwd_mat[i - 1, j] + shift
Expand Down Expand Up @@ -390,17 +402,18 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
bwd_mat = np.zeros((m, h), dtype=np.float64)
bwd_mat[-1, :] = 1.0 / h # Initialise the last column.
for i in range(m - 2, -1, -1):
query_a = query_h[i + 1]
iP1 = i + 1
query_a = query_h[iP1]
for j in range(h):
ref_a = ref_h[i + 1, j]
em_prob = mismatch_probs[i + 1]
if ref_a == query_a:
em_prob = 1.0 - mismatch_probs[i + 1]
bwd_mat[i + 1, j] *= em_prob
site_sum = np.sum(bwd_mat[i + 1, :])
scale = (1 - trans_probs[i + 1]) / site_sum
shift = trans_probs[i + 1] / h
bwd_mat[i, :] = scale * bwd_mat[i + 1, :] + shift
ref_a = ref_h[iP1, j]
em_prob = compute_emission_probability(
mismatch_probs[iP1], query_a == ref_a
)
bwd_mat[iP1, j] *= em_prob
site_sum = np.sum(bwd_mat[iP1, :])
scale = (1 - trans_probs[iP1]) / site_sum
shift = trans_probs[iP1] / h
bwd_mat[i, :] = scale * bwd_mat[iP1, :] + shift
return bwd_mat


Expand Down

0 comments on commit 4393ad7

Please sign in to comment.