diff --git a/python/tests/beagle.py b/python/tests/beagle.py index e742a5ebaf..6b28161946 100644 --- a/python/tests/beagle.py +++ b/python/tests/beagle.py @@ -426,9 +426,7 @@ def compute_state_probability_matrix(fm, bm): return sm -def interpolate_allele_probabilities( - sm, ref_h, alleles, genotyped_pos, ungenotyped_pos -): +def interpolate_allele_probabilities(sm, ref_h, genotyped_pos, ungenotyped_pos): """ Compute the interpolated allele probabilities at ungenotyped markers of a query haplotype following Equation 1 of BB2016. @@ -444,19 +442,16 @@ def interpolate_allele_probabilities( :param numpy.ndarray sm: HMM state probability matrix at genotyped markers. :param numpy.ndarray ref_h: Reference haplotypes subsetted to ungenotyped markers. - :param numpy.ndarray alleles: Alleles (ancestral/derived or ACGT encoding). :param numpy.ndarray genotyped_pos: Site positions at genotyped markers. :param numpy.ndarray ungenotyped_pos: Site positions at ungenotyped markers. :return: Interpolated allele probabilities. :rtype: numpy.ndarray """ + alleles = np.arange(4) # ACGT encoding assert not np.any(sm < 0), "HMM state probability matrix has negative values." assert not np.any(np.isnan(sm)), "HMM state probability matrix has NaN values." m = sm.shape[0] - 1 h = sm.shape[1] - assert np.all( - np.isin(ref_h, alleles) - ), f"Reference haplotypes have alleles absent in {alleles}." assert m == len(genotyped_pos) x = len(ungenotyped_pos) assert (x, h) == ref_h.shape @@ -600,7 +595,7 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False): # Interpolate allele probabilities at ungenotyped markers. alleles = np.arange(4) # ACGT i_allele_probs = interpolate_allele_probabilities( - sm, ref_h_ungenotyped, alleles, genotyped_pos, ungenotyped_pos + sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos ) if debug: print("Interpolated allele probabilities") diff --git a/python/tests/beagle_numba.py b/python/tests/beagle_numba.py index bff4dcfcaa..df4d68ddb6 100644 --- a/python/tests/beagle_numba.py +++ b/python/tests/beagle_numba.py @@ -217,9 +217,7 @@ def compute_state_probability_matrix(fm, bm): @njit -def interpolate_allele_probabilities( - sm, ref_h, alleles, genotyped_pos, ungenotyped_pos -): +def interpolate_allele_probabilities(sm, ref_h, genotyped_pos, ungenotyped_pos): """ Compute the interpolated allele probabilities at ungenotyped markers of a query haplotype following Equation 1 of BB2016. @@ -235,12 +233,12 @@ def interpolate_allele_probabilities( :param numpy.ndarray sm: HMM state probability matrix at genotyped markers. :param numpy.ndarray ref_h: Reference haplotypes subsetted to imputed markers. - :param numpy.ndarray alleles: Alleles (ancestral/derived or ACGT encoding). :param numpy.ndarray genotyped_pos: Site positions at genotyped markers. :param numpy.ndarray ungenotyped_pos: Site positions at ungenotyped markers. :return: Interpolated allele probabilities. :rtype: numpy.ndarray """ + alleles = np.arange(4) # ACGT encoding x = len(ungenotyped_pos) weights, marker_interval_start = get_weights(genotyped_pos, ungenotyped_pos) p = np.zeros((x, len(alleles)), dtype=np.float64) @@ -297,7 +295,6 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6): :return: Imputed alleles and their associated probabilities. :rtype: tuple(numpy.ndarray, numpy.ndarray) """ - alleles = np.arange(4) # ACGT encoding # Indices of markers. genotyped_pos_idx = np.where(query_h != -1)[0] ungenotyped_pos_idx = np.where(query_h == -1)[0] @@ -321,7 +318,7 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6): ref_h_ungenotyped = ref_h[ungenotyped_pos_idx, :] # Interpolate allele probabilities at ungenotyped markers. i_allele_probs = interpolate_allele_probabilities( - sm, ref_h_ungenotyped, alleles, genotyped_pos, ungenotyped_pos + sm, ref_h_ungenotyped, genotyped_pos, ungenotyped_pos ) # Get MAP alleles at ungenotyped markers. imputed_alleles, max_allele_probs = get_map_alleles(i_allele_probs) @@ -329,7 +326,6 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6): def run_tsimpute(ref_ts, query_h, pos): - alleles = np.arange(4) # ACGT encoding # Prepare marker positions. genotyped_site_ids = np.where(query_h != -1)[0] genotyped_pos = pos[genotyped_site_ids] @@ -354,7 +350,7 @@ def run_tsimpute(ref_ts, query_h, pos): sm = compute_state_probability_matrix(fm.decode(), bm.decode()) # Interpolate allele probabilities. allele_probs = interpolate_allele_probabilities( - sm, ref_h_x, alleles, genotyped_pos, imputed_pos + sm, ref_h_x, genotyped_pos, imputed_pos ) # Get MAP alleles at imputed markers. imputed_alleles, max_allele_probs = get_map_alleles(allele_probs)