Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 13, 2023
1 parent e38675d commit efbe1fe
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
19 changes: 10 additions & 9 deletions python/tests/beagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,45 +585,45 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False):
if debug:
print("Reference haplotypes subsetted to genotyped markers")
print(ref_h_genotyped)
assert ref_h_genotyped.shape == (m, h)
assert (m, h) == ref_h_genotyped.shape
# Subset the query haplotype to genotyped markers
query_h_genotyped = query_h[genotyped_pos_idx]
if debug:
print("Query haplotype subsetted to genotyped markers")
print(query_h_genotyped)
assert len(query_h_genotyped) == m
assert m == len(query_h_genotyped)
# Set mismatch probabilities at genotyped markers
mu = get_mismatch_prob(genotyped_pos, miscall_rate=miscall_rate)
if debug:
print("Mismatch probabilities")
print(mu)
assert len(mu) == m
assert m == len(mu)
# Set switch probabilities at genotyped markers
rho = get_switch_prob(genotyped_pos, h, ne=ne)
if debug:
print("Switch probabilities")
print(rho)
assert len(rho) == m
assert m == len(rho)
# Compute forward probability matrix at genotyped markers
fm = compute_forward_probability_matrix(ref_h_genotyped, query_h_genotyped, rho, mu)
if debug:
print("Forward probability matrix")
print(fm)
assert fm.shape == (m, h)
assert (m, h) == fm.shape
# Compute backward probability matrix at genotyped markers
bm = compute_backward_probability_matrix(
ref_h_genotyped, query_h_genotyped, rho, mu
)
if debug:
print("Backward probability matrix")
print(bm)
assert bm.shape == (m, h)
assert (m, h) == bm.shape
# Compute HMM state probability matrix at genotyped markers
sm = compute_state_probability_matrix(fm, bm, ref_h_genotyped, query_h_genotyped)
if debug:
print("HMM state probability matrix")
print(sm)
assert sm.shape == (m + 1, h)
assert (m + 1, h) == sm.shape
# Subset the reference haplotypes to imputed markers
ref_h_imputed = ref_h[imputed_pos_idx, :]
# Interpolate allele probabilities at imputed markers
Expand All @@ -633,13 +633,14 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False):
if debug:
print("Interpolated allele probabilities")
print(i_allele_probs)
assert i_allele_probs.shape == (x, 2)
# TODO: Allow for multiallelic sites.
assert (x, 2) == i_allele_probs.shape
# Get MAP alleles at imputed markers
imputed_alleles = get_map_alleles(i_allele_probs)
if debug:
print("Imputed alleles")
print(imputed_alleles)
assert len(imputed_alleles) == x
assert x == len(imputed_alleles)
return (imputed_alleles, i_allele_probs)


Expand Down
19 changes: 10 additions & 9 deletions python/tests/beagle_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,37 +367,38 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6):
assert m + x == len(pos)
# Subset the reference haplotypes to genotyped markers
ref_h_genotyped = ref_h[genotyped_pos_idx, :]
assert ref_h_genotyped.shape == (m, h)
assert (m, h) == ref_h_genotyped.shape
# Subset the query haplotype to genotyped markers
query_h_genotyped = query_h[genotyped_pos_idx]
assert len(query_h_genotyped) == m
assert m == len(query_h_genotyped)
# Set mismatch probabilities at genotyped markers
mu = get_mismatch_prob(genotyped_pos, miscall_rate=miscall_rate)
assert len(mu) == m
assert m == len(mu)
# Set switch probabilities at genotyped markers
rho = get_switch_prob(genotyped_pos, h, ne=ne)
assert len(rho) == m
assert m == len(rho)
# Compute forward probability matrix at genotyped markers
fm = compute_forward_probability_matrix(ref_h_genotyped, query_h_genotyped, rho, mu)
assert fm.shape == (m, h)
assert (m, h) == fm.shape
# Compute backward probability matrix at genotyped markers
bm = compute_backward_probability_matrix(
ref_h_genotyped, query_h_genotyped, rho, mu
)
assert bm.shape == (m, h)
assert (m, h) == bm.shape
# Compute HMM state probability matrix at genotyped markers
sm = compute_state_probability_matrix(fm, bm, ref_h_genotyped, query_h_genotyped)
assert sm.shape == (m + 1, h)
assert (m + 1, h) == sm.shape
# Subset the reference haplotypes to imputed markers
ref_h_imputed = ref_h[imputed_pos_idx, :]
# Interpolate allele probabilities at imputed markers
i_allele_probs = interpolate_allele_probabilities(
sm, ref_h_imputed, genotyped_pos, imputed_pos
)
assert i_allele_probs.shape == (x, 2)
# TODO: Allow for multiallelic sites.
assert (x, 2) == i_allele_probs.shape
# Get MAP alleles at imputed markers
imputed_alleles = get_map_alleles(i_allele_probs)
assert len(imputed_alleles) == x
assert x == len(imputed_alleles)
return (imputed_alleles, i_allele_probs)


Expand Down

0 comments on commit efbe1fe

Please sign in to comment.