Skip to content

Commit

Permalink
Implement mutation rate scaling based on the number of alleles
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Mar 4, 2024
1 parent e601248 commit 908715c
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,14 @@ def get_transition_probs(cm, h, ne):

def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2):
"""
Compute emission probability at a site based on whether the alleles
carried by a query haplotype and a reference haplotype match at the site.
The emission probability may be scaled according to the number of distinct
segregating alleles. By default, it is assumed the site is biallelic.
:param float mismatch_prob: Mismatch probability.
:param bool is_match: True if match, otherwise False.
:param bool is_match: True if matched, otherwise False.
:param int num_alleles: Number of distinct alleles (default = 2).
:return: Emission probability
:rtype: float
Expand All @@ -337,7 +343,9 @@ def compute_emission_probability(mismatch_prob, is_match, *, num_alleles=2):


@njit
def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
def compute_forward_matrix(
ref_h, query_h, trans_probs, mismatch_probs, *, num_alleles=2
):
"""
Implement Li and Stephens forward algorithm.
Expand All @@ -350,6 +358,7 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
:param numpy.ndarray query_h: One query haplotype.
:param numpy.ndarray trans_probs: Transition probabilities.
:param numpy.ndarray mismatch_probs: Mismatch probabilities.
:param int num_alleles: Number of distinct alleles (default = 2).
:return: Forward probability matrix.
:rtype: numpy.ndarray
"""
Expand All @@ -367,7 +376,9 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
# Get allele at genotyped position i on reference haplotype j.
ref_a = ref_h[i, j]
# Get emission probability.
em_prob = compute_emission_probability(mismatch_probs[i], query_a == ref_a)
em_prob = compute_emission_probability(
mismatch_probs[i], query_a == ref_a, num_alleles=num_alleles
)
fwd_mat[i, j] = em_prob
if i > 0:
fwd_mat[i, j] *= scale * fwd_mat[i - 1, j] + shift
Expand All @@ -378,7 +389,9 @@ def compute_forward_matrix(ref_h, query_h, trans_probs, mismatch_probs):


@njit
def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
def compute_backward_matrix(
ref_h, query_h, trans_probs, mismatch_probs, *, num_alleles=2
):
"""
Implement Li and Stephens backward algorithm.
Expand All @@ -394,6 +407,7 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
:param numpy.ndarray query_h: One query haplotype.
:param numpy.ndarray trans_probs: Transition probabilities.
:param numpy.ndarray mismatch_probs: Mismatch probabilities.
:param int num_alleles: Number of distinct alleles (default = 2).
:return: Backward probability matrix.
:rtype: numpy.ndarray
"""
Expand All @@ -407,7 +421,9 @@ def compute_backward_matrix(ref_h, query_h, trans_probs, mismatch_probs):
for j in range(h):
ref_a = ref_h[iP1, j]
em_prob = compute_emission_probability(
mismatch_probs[iP1], query_a == ref_a
mismatch_probs[iP1],
query_a == ref_a,
num_alleles=num_alleles,
)
bwd_mat[iP1, j] *= em_prob
site_sum = np.sum(bwd_mat[iP1, :])
Expand Down Expand Up @@ -635,6 +651,7 @@ def run_interpolation_beagle(
"Check the reference and query haplotypes use the same allele encoding.",
stacklevel=1,
)
num_alleles = len(tskit.ALLELES_ACGT)
h = ref_h.shape[1] # Number of reference haplotypes.
# Separate indices of genotyped and ungenotyped positions.
idx_typed = np.where(query_h != tskit.MISSING_DATA)[0]
Expand All @@ -654,10 +671,18 @@ def run_interpolation_beagle(
mismatch_probs = get_mismatch_probs(pos_typed, error_rate=error_rate)
# Compute matrices at genotyped positions.
fwd_mat = compute_forward_matrix(
ref_h_typed, query_h_typed, trans_probs, mismatch_probs
ref_h_typed,
query_h_typed,
trans_probs,
mismatch_probs,
num_alleles=num_alleles,
)
bwd_mat = compute_backward_matrix(
ref_h_typed, query_h_typed, trans_probs, mismatch_probs
ref_h_typed,
query_h_typed,
trans_probs,
mismatch_probs,
num_alleles=num_alleles,
)
state_mat = compute_state_prob_matrix(fwd_mat, bwd_mat)
# Interpolate allele probabilities.
Expand Down

0 comments on commit 908715c

Please sign in to comment.