Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jan 14, 2024
1 parent 96f455e commit 3fc04e2
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def interpolate_allele_prob(sm, ref_h, typed_pos, untyped_pos, typed_cm, untyped
Note that this function takes:
1. HMM state probability matrix across genotyped markers of size (m, h).
2. reference haplotypes subsetted to ungenotyped markers of size (x, h).
2. Reference haplotypes subsetted to ungenotyped markers of size (x, h).
:param numpy.ndarray sm: HMM state probability matrix at genotyped markers.
:param numpy.ndarray ref_h: Reference haplotypes subsetted to imputed markers.
Expand All @@ -316,17 +316,23 @@ def interpolate_allele_prob(sm, ref_h, typed_pos, untyped_pos, typed_cm, untyped
"""
# TODO: Allow for biallelic site matrix. Work with `_tskit.lshmm` properly.
alleles = np.arange(4) # ACGT
m = sm.shape[0]
x = ref_h.shape[0]
h = ref_h.shape[1]
x = len(untyped_pos)
weights, left_idx = get_weights(typed_pos, untyped_pos, typed_cm, untyped_cm)
p = np.zeros((x, len(alleles)), dtype=np.float32)
for i in range(x):
m = left_idx[i]
k = left_idx[i]
w = weights[i]
for j in range(h):
for a in alleles:
if ref_h[i, j] == a:
p[i, a] += weights[i] * sm[m, j]
p[i, a] += (1 - weights[i]) * sm[m + 1, j]
if k == m - 1:
# TODO: Consult BEAGLE source code.
pass
else:
p[i, a] += w * sm[k, j]
p[i, a] += (1 - w) * sm[k + 1, j]
# Rescale probabilities.
# TODO: Check if this is necessary. Could this be a subtle source of error?
p_rescaled = p / np.sum(p, axis=1)[:, np.newaxis]
Expand Down

0 comments on commit 3fc04e2

Please sign in to comment.