Skip to content

Commit

Permalink
Rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jan 14, 2024
1 parent b43a80a commit 22e1cb7
Showing 1 changed file with 49 additions and 55 deletions.
104 changes: 49 additions & 55 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,14 @@ def compute_state_prob_matrix(fm, bm):


@njit
def _interpolate_allele_prob(
sm, ref_h, genotyped_pos, ungenotyped_pos, genotyped_cm, ungenotyped_cm
):
def _interpolate_allele_prob(sm, ref_h, typed_pos, untyped_pos, typed_cm, untyped_cm):
"""
DEPRECATED: Use `interpolate_allele_prob`, which is 2x as fast in benchmark runs.
"""
alleles = np.arange(4) # ACGT
x = len(ungenotyped_pos)
x = len(untyped_pos)
weights, marker_interval_start = get_weights(
genotyped_pos, ungenotyped_pos, genotyped_cm, ungenotyped_cm
typed_pos, untyped_pos, typed_cm, untyped_cm
)
p = np.zeros((x, len(alleles)), dtype=np.float32)
for a in alleles:
Expand All @@ -306,9 +304,7 @@ def _interpolate_allele_prob(


@njit
def interpolate_allele_prob(
sm, ref_h, genotyped_pos, ungenotyped_pos, genotyped_cm, ungenotyped_cm
):
def interpolate_allele_prob(sm, ref_h, typed_pos, untyped_pos, typed_cm, untyped_cm):
"""
Compute the interpolated allele probabilities at ungenotyped markers of
a query haplotype following Equation 1 of BB2016.
Expand All @@ -333,10 +329,8 @@ def interpolate_allele_prob(
"""
alleles = np.arange(4) # ACGT
ref_panel_size = ref_h.shape[1]
x = len(ungenotyped_pos)
weights, left_idx = get_weights(
genotyped_pos, ungenotyped_pos, genotyped_cm, ungenotyped_cm
)
x = len(untyped_pos)
weights, left_idx = get_weights(typed_pos, untyped_pos, typed_cm, untyped_cm)
p = np.zeros((x, len(alleles)), dtype=np.float32)
for i in range(x):
m = left_idx[i]
Expand Down Expand Up @@ -399,35 +393,35 @@ def run_beagle(
:rtype: tuple(numpy.ndarray, numpy.ndarray)
"""
# Set indices of markers.
genotyped_idx = np.where(query_h != -1)[0]
ungenotyped_idx = np.where(query_h == -1)[0]
typed_idx = np.where(query_h != -1)[0]
untyped_idx = np.where(query_h == -1)[0]
# Set site positions of markers.
genotyped_pos = pos[genotyped_idx]
ungenotyped_pos = pos[ungenotyped_idx]
typed_pos = pos[typed_idx]
untyped_pos = pos[untyped_idx]
# Get genetic map positions of markers.
genotyped_cm = convert_to_genetic_map_positions(genotyped_pos, genetic_map)
ungenotyped_cm = convert_to_genetic_map_positions(ungenotyped_pos, genetic_map)
typed_cm = convert_to_genetic_map_positions(typed_pos, genetic_map)
untyped_cm = convert_to_genetic_map_positions(untyped_pos, genetic_map)
# Subset haplotypes.
ref_h_genotyped = ref_h[genotyped_idx, :]
ref_h_ungenotyped = ref_h[ungenotyped_idx, :]
query_h_genotyped = query_h[genotyped_idx]
ref_h_typed = ref_h[typed_idx, :]
ref_h_untyped = ref_h[untyped_idx, :]
query_h_typed = query_h[typed_idx]
# Set switch and mismatch probabilities at genotyped markers.
mu = get_mismatch_prob(genotyped_pos, error_rate=error_rate)
mu = get_mismatch_prob(typed_pos, error_rate=error_rate)
# TODO: Use genetic map if provided.
h = ref_h.shape[1]
rho = get_switch_prob(genotyped_cm, h=h, ne=ne) # Be careful!
rho = get_switch_prob(typed_cm, h=h, ne=ne) # Be careful!
# Compute the matrices at genotyped markers.
fm = compute_forward_matrix(ref_h_genotyped, query_h_genotyped, rho, mu)
bm = compute_backward_matrix(ref_h_genotyped, query_h_genotyped, rho, mu)
fm = compute_forward_matrix(ref_h_typed, query_h_typed, rho, mu)
bm = compute_backward_matrix(ref_h_typed, query_h_typed, rho, mu)
sm = compute_state_prob_matrix(fm, bm)
# Interpolate allele probabilities at ungenotyped markers.
i_allele_prob = interpolate_allele_prob(
sm,
ref_h_ungenotyped,
genotyped_pos,
ungenotyped_pos,
genotyped_cm,
ungenotyped_cm,
ref_h_untyped,
typed_pos,
untyped_pos,
typed_cm,
untyped_cm,
)
# Get MAP alleles at ungenotyped markers.
imputed_alleles, max_allele_prob = get_map_alleles(i_allele_prob)
Expand All @@ -450,9 +444,9 @@ def run_tsimpute(
TODO: Put this function elsewhere.
TODO: Set default precision. What should it be?
:param numpy.ndarray ref_ts: Tree sequence containing reference haplotypes.
:param numpy.ndarray ref_ts: Tree sequence with reference haplotypes.
:param numpy.ndarray query_h: One query haplotype.
:param numpy.ndarray pos: Physical positions of all the markers.
:param numpy.ndarray pos: Physical positions of all the markers (bp).
:param numpy.ndarray mu: Mutation rate.
:param numpy.ndarray rho: Recombination rate.
:param int precision: Precision when running LS HMM (default = 22).
Expand All @@ -461,46 +455,46 @@ def run_tsimpute(
:return: Imputed alleles and their interpolated probabilities.
:rtype: tuple(numpy.ndarray, numpy.ndarray)
"""
# Set indices of markers.
genotyped_idx = np.where(query_h != -1)[0]
ungenotyped_idx = np.where(query_h == -1)[0]
# Set markers indices.
typed_idx = np.where(query_h != -1)[0]
untyped_idx = np.where(query_h == -1)[0]
# Set site positions of markers.
genotyped_pos = pos[genotyped_idx]
ungenotyped_pos = pos[ungenotyped_idx]
typed_pos = pos[typed_idx]
untyped_pos = pos[untyped_idx]
# Get genetic map positions of markers.
genotyped_cm = convert_to_genetic_map_positions(genotyped_pos, genetic_map)
ungenotyped_cm = convert_to_genetic_map_positions(ungenotyped_pos, genetic_map)
typed_cm = convert_to_genetic_map_positions(typed_pos, genetic_map)
untyped_cm = convert_to_genetic_map_positions(untyped_pos, genetic_map)
# Get parameters at genotyped markers.
mu = mu[genotyped_idx]
rho = rho[genotyped_idx]
mu = mu[typed_idx]
rho = rho[typed_idx]
# Subset haplotypes.
ref_ts_genotyped = ref_ts.delete_sites(site_ids=ungenotyped_idx)
ref_ts_ungenotyped = ref_ts.delete_sites(site_ids=genotyped_idx)
ref_h_ungenotyped = ref_ts_ungenotyped.genotype_matrix(alleles=tskit.ALLELES_ACGT)
query_h_genotyped = query_h[genotyped_idx]
ref_ts_typed = ref_ts.delete_sites(site_ids=untyped_idx)
ref_ts_untyped = ref_ts.delete_sites(site_ids=typed_idx)
ref_h_untyped = ref_ts_untyped.genotype_matrix(alleles=tskit.ALLELES_ACGT)
query_h_typed = query_h[typed_idx]
# Get matrices from tree sequence.
fm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
bm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
fm = _tskit.CompressedMatrix(ref_ts_typed._ll_tree_sequence)
bm = _tskit.CompressedMatrix(ref_ts_typed._ll_tree_sequence)
# WARN: Be careful with the positional arguments rho and mu!!!
# WARN: Be careful with the argument `acgt_alleles`!!!
ls_hmm = _tskit.LsHmm(
ref_ts_genotyped._ll_tree_sequence,
ref_ts_typed._ll_tree_sequence,
rho,
mu,
acgt_alleles=True,
precision=precision,
)
ls_hmm.forward_matrix(query_h_genotyped.T, fm)
ls_hmm.backward_matrix(query_h_genotyped.T, fm.normalisation_factor, bm)
ls_hmm.forward_matrix(query_h_typed.T, fm)
ls_hmm.backward_matrix(query_h_typed.T, fm.normalisation_factor, bm)
sm = compute_state_prob_matrix(fm.decode(), bm.decode())
# Perform linear interpolation.
int_allele_prob = interpolate_allele_prob(
sm,
ref_h_ungenotyped,
genotyped_pos,
ungenotyped_pos,
genotyped_cm,
ungenotyped_cm,
ref_h_untyped,
typed_pos,
untyped_pos,
typed_cm,
untyped_cm,
)
imputed_alleles, max_allele_prob = get_map_alleles(int_allele_prob)
return (imputed_alleles, max_allele_prob)

0 comments on commit 22e1cb7

Please sign in to comment.