Skip to content

Commit

Permalink
Allow only for ACGT encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 15, 2023
1 parent 89aef67 commit 4412e31
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
11 changes: 3 additions & 8 deletions python/tests/beagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,7 @@ def compute_state_probability_matrix(fm, bm):
return sm


def interpolate_allele_probabilities(
sm, ref_h, alleles, genotyped_pos, ungenotyped_pos
):
def interpolate_allele_probabilities(sm, ref_h, genotyped_pos, ungenotyped_pos):
"""
Compute the interpolated allele probabilities at ungenotyped markers of
a query haplotype following Equation 1 of BB2016.
Expand All @@ -444,19 +442,16 @@ def interpolate_allele_probabilities(
:param numpy.ndarray sm: HMM state probability matrix at genotyped markers.
:param numpy.ndarray ref_h: Reference haplotypes subsetted to ungenotyped markers.
:param numpy.ndarray alleles: Alleles (ancestral/derived or ACGT encoding).
:param numpy.ndarray genotyped_pos: Site positions at genotyped markers.
:param numpy.ndarray ungenotyped_pos: Site positions at ungenotyped markers.
:return: Interpolated allele probabilities.
:rtype: numpy.ndarray
"""
alleles = np.arange(4) # ACGT encoding
assert not np.any(sm < 0), "HMM state probability matrix has negative values."
assert not np.any(np.isnan(sm)), "HMM state probability matrix has NaN values."
m = sm.shape[0] - 1
h = sm.shape[1]
assert np.all(
np.isin(ref_h, alleles)
), f"Reference haplotypes have alleles absent in {alleles}."
assert m == len(genotyped_pos)
x = len(ungenotyped_pos)
assert (x, h) == ref_h.shape
Expand Down Expand Up @@ -600,7 +595,7 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False):
# Interpolate allele probabilities at ungenotyped markers.
alleles = np.arange(4) # ACGT
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_ungenotyped, alleles, genotyped_pos, ungenotyped_pos
sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos
)
if debug:
print("Interpolated allele probabilities")
Expand Down
12 changes: 4 additions & 8 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ def compute_state_probability_matrix(fm, bm):


@njit
def interpolate_allele_probabilities(
sm, ref_h, alleles, genotyped_pos, ungenotyped_pos
):
def interpolate_allele_probabilities(sm, ref_h, genotyped_pos, ungenotyped_pos):
"""
Compute the interpolated allele probabilities at ungenotyped markers of
a query haplotype following Equation 1 of BB2016.
Expand All @@ -235,12 +233,12 @@ def interpolate_allele_probabilities(
:param numpy.ndarray sm: HMM state probability matrix at genotyped markers.
:param numpy.ndarray ref_h: Reference haplotypes subsetted to imputed markers.
:param numpy.ndarray alleles: Alleles (ancestral/derived or ACGT encoding).
:param numpy.ndarray genotyped_pos: Site positions at genotyped markers.
:param numpy.ndarray ungenotyped_pos: Site positions at ungenotyped markers.
:return: Interpolated allele probabilities.
:rtype: numpy.ndarray
"""
alleles = np.arange(4) # ACGT encoding
x = len(ungenotyped_pos)
weights, marker_interval_start = get_weights(genotyped_pos, ungenotyped_pos)
p = np.zeros((x, len(alleles)), dtype=np.float64)
Expand Down Expand Up @@ -297,7 +295,6 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):
:return: Imputed alleles and their associated probabilities.
:rtype: tuple(numpy.ndarray, numpy.ndarray)
"""
alleles = np.arange(4) # ACGT encoding
# Indices of markers.
genotyped_pos_idx = np.where(query_h != -1)[0]
ungenotyped_pos_idx = np.where(query_h == -1)[0]
Expand All @@ -321,15 +318,14 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):
ref_h_ungenotyped = ref_h[ungenotyped_pos_idx, :]
# Interpolate allele probabilities at ungenotyped markers.
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_ungenotyped, alleles, genotyped_pos, ungenotyped_pos
sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos
)
# Get MAP alleles at ungenotyped markers.
imputed_alleles, max_allele_probs = get_map_alleles(i_allele_probs)
return (imputed_alleles, max_allele_probs)


def run_tsimpute(ref_ts, query_h, pos):
alleles = np.arange(4) # ACGT encoding
# Prepare marker positions.
genotyped_site_ids = np.where(query_h != -1)[0]
genotyped_pos = pos[genotyped_site_ids]
Expand All @@ -354,7 +350,7 @@ def run_tsimpute(ref_ts, query_h, pos):
sm = compute_state_probability_matrix(fm.decode(), bm.decode())
# Interpolate allele probabilities.
allele_probs = interpolate_allele_probabilities(
sm, ref_h_x, alleles, genotyped_pos, imputed_pos
sm, ref_h_x, genotyped_pos, imputed_pos
)
# Get MAP alleles at imputed markers.
imputed_alleles, max_allele_probs = get_map_alleles(allele_probs)
Expand Down

0 comments on commit 4412e31

Please sign in to comment.