Skip to content

Commit

Permalink
Make mu and rho input arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 20, 2023
1 parent 5053898 commit d2b4444
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
29 changes: 16 additions & 13 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,22 +324,25 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):
return (imputed_alleles, max_allele_probs)


def run_tsimpute(ref_ts, query_h, pos):
def run_tsimpute(ref_ts, query_h, pos, mu, rho):
"""
TODO: Document this function.
TODO: Put this function elsewhere.
"""
# Set indices of markers.
genotyped_pos_idx = np.where(query_h != -1)[0]
ungenotyped_pos_idx = np.where(query_h == -1)[0]
genotyped_site_idx = np.where(query_h != -1)[0]
ungenotyped_site_idx = np.where(query_h == -1)[0]
# Set site positions of markers.
genotyped_pos = pos[genotyped_pos_idx]
ungenotyped_pos = pos[ungenotyped_pos_idx]
# Set parametrization, which differs between BEAGLE and tsinfer.
mu = get_mismatch_prob(genotyped_pos, miscall_rate=1e-8)
rho = get_switch_prob(genotyped_pos, h=ref_ts.num_samples, ne=10_000.0)
rho /= 1e5
genotyped_site_pos = pos[genotyped_site_idx]
ungenotyped_site_pos = pos[ungenotyped_site_idx]
# Get parameters at genotyped markers.
mu = mu[genotyped_site_idx]
rho = rho[genotyped_site_idx]
# Subset haplotypes.
ref_ts_genotyped = ref_ts.delete_sites(site_ids=ungenotyped_pos_idx)
ref_ts_ungenotyped = ref_ts.delete_sites(site_ids=genotyped_pos_idx)
ref_ts_genotyped = ref_ts.delete_sites(site_ids=ungenotyped_site_idx)
ref_ts_ungenotyped = ref_ts.delete_sites(site_ids=genotyped_site_idx)
ref_h_ungenotyped = ref_ts_ungenotyped.genotype_matrix()
query_h_genotyped = query_h[genotyped_pos_idx].astype(np.int32)
query_h_genotyped = query_h[genotyped_site_idx].astype(np.int32)
# Get forward and backward matrices from tree sequence.
fm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
bm = _tskit.CompressedMatrix(ref_ts_genotyped._ll_tree_sequence)
Expand All @@ -352,7 +355,7 @@ def run_tsimpute(ref_ts, query_h, pos):
sm = compute_state_probability_matrix(fm.decode(), bm.decode())
# Interpolate allele probabilities.
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos
sm, ref_h_ungenotyped, genotyped_site_pos, ungenotyped_site_pos
)
# Get MAP alleles at ungenotyped markers.
imputed_alleles, max_allele_probs = get_map_alleles(i_allele_probs)
Expand Down
7 changes: 6 additions & 1 deletion python/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,16 +640,21 @@ def parse_matrix(csv_text):
],
)
def test_tsimpute(input_ref, input_query):
"""
Test whether the outputs of tsimpute and Python BEAGLE implementation match.
"""
# TODO: Compare interpolated allele probabilities.
toy_ref_ts = get_toy_data() # Same for both cases
pos = toy_ref_ts.sites_position
num_query_haps = input_query.shape[0]
mu = np.zeros(len(pos), dtype=np.float32) + 1e-8
rho = np.zeros(len(pos), dtype=np.float32) + 1e-8
for i in np.arange(num_query_haps):
imputed_alleles, _ = tests.beagle.run_beagle(
input_ref, input_query[i], pos, miscall_rate=0.0001, ne=10.0
)
imputed_alleles_ts, _ = tests.beagle_numba.run_tsimpute(
toy_ref_ts, input_query[i], pos
toy_ref_ts, input_query[i], pos, mu, rho
)
np.testing.assert_array_equal(imputed_alleles, imputed_alleles_ts)

Expand Down

0 comments on commit d2b4444

Please sign in to comment.