Skip to content

Commit

Permalink
Refactor
Browse files: browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 17, 2023
1 parent 110e253 commit e4e922f
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,27 +296,25 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):
:return: Imputed alleles and their associated probabilities.
:rtype: tuple(numpy.ndarray, numpy.ndarray)
"""
# Indices of markers.
# Set indices of markers.
genotyped_pos_idx = np.where(query_h != -1)[0]
ungenotyped_pos_idx = np.where(query_h == -1)[0]
# Site positions of markers.
# Set site positions of markers.
genotyped_pos = pos[genotyped_pos_idx]
ungenotyped_pos = pos[ungenotyped_pos_idx]
h = ref_h.shape[1]
# Subset haplotypes to genotyped markers.
# Subset haplotypes.
ref_h_genotyped = ref_h[genotyped_pos_idx, :]
ref_h_ungenotyped = ref_h[ungenotyped_pos_idx, :]
query_h_genotyped = query_h[genotyped_pos_idx]
# Set switch and mismatch probabilities at genotyped markers.
mu = get_mismatch_prob(genotyped_pos, miscall_rate=miscall_rate)
rho = get_switch_prob(genotyped_pos, h, ne=ne)
rho = get_switch_prob(genotyped_pos, h=ref_h.shape[1], ne=ne)
# Compute the HMM matrices at genotyped markers.
fm = compute_forward_probability_matrix(ref_h_genotyped, query_h_genotyped, rho, mu)
bm = compute_backward_probability_matrix(
ref_h_genotyped, query_h_genotyped, rho, mu
)
sm = compute_state_probability_matrix(fm, bm)
# Subset the reference haplotypes to ungenotyped markers.
ref_h_ungenotyped = ref_h[ungenotyped_pos_idx, :]
# Interpolate allele probabilities at ungenotyped markers.
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos
Expand All @@ -327,32 +325,35 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):


def run_tsimpute(ref_ts, query_h, pos):
# Prepare marker positions.
genotyped_site_ids = np.where(query_h != -1)[0]
genotyped_pos = pos[genotyped_site_ids]
imputed_site_ids = np.where(query_h == -1)[0]
imputed_pos = pos[imputed_site_ids]
# Note that parametrization of BEAGLE and tsinfer
# Set indices of markers.
genotyped_pos_idx = np.where(query_h != -1)[0]
ungenotyped_pos_idx = np.where(query_h == -1)[0]
# Set site positions of markers.
genotyped_pos = pos[genotyped_pos_idx]
ungenotyped_pos = pos[ungenotyped_pos_idx]
# Set parametrization, which differs between BEAGLE and tsinfer.
mu = get_mismatch_prob(genotyped_pos, miscall_rate=1e-8)
rho = get_switch_prob(genotyped_pos, h=ref_ts.num_samples, ne=10_000.0)
rho /= 1e5
# Prepare reference haplotypes.
ref_ts_m = ref_ts.delete_sites(site_ids=imputed_site_ids)
ref_ts_x = ref_ts.delete_sites(site_ids=genotyped_site_ids)
ref_h_x = ref_ts_x.genotype_matrix()
query_h_m = query_h[genotyped_site_ids].astype(np.int32)
# Get forward and backward matrices from ts.
fm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence)
bm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence)
ls_hmm = _tskit.LsHmm(ref_ts_m._ll_tree_sequence, mu, rho, acgt_alleles=True)
ls_hmm.forward_matrix(query_h_m.T, fm)
ls_hmm.backward_matrix(query_h_m.T, fm.normalisation_factor, bm)
# Subset haplotypes.
ref_ts_genotyped = ref_ts.delete_sites(site_ids=ungenotyped_pos_idx)
ref_ts_ungenotyped = ref_ts.delete_sites(site_ids=genotyped_pos_idx)
ref_h_ungenotyped = ref_ts_ungenotyped.genotype_matrix()
query_h_genotyped = query_h[genotyped_pos_idx].astype(np.int32)
# Get forward and backward matrices from tree sequence.
fm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
bm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
ls_hmm = _tskit.LsHmm(
ref_ts_genotyped._ll_tree_sequence, mu, rho, acgt_alleles=True
)
ls_hmm.forward_matrix(query_h_genotyped.T, fm)
ls_hmm.backward_matrix(query_h_genotyped.T, fm.normalisation_factor, bm)
# Compute state probability matrix.
sm = compute_state_probability_matrix(fm.decode(), bm.decode())
# Interpolate allele probabilities.
allele_probs = interpolate_allele_probabilities(
sm, ref_h_x, genotyped_pos, imputed_pos
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos
)
# Get MAP alleles at imputed markers.
imputed_alleles, max_allele_probs = get_map_alleles(allele_probs)
# Get MAP alleles at ungenotyped markers.
imputed_alleles, max_allele_probs = get_map_alleles(i_allele_probs)
return (imputed_alleles, max_allele_probs)

0 comments on commit e4e922f

Please sign in to comment.