diff --git a/python/tests/beagle.py b/python/tests/beagle.py index 17c94b6e1e..312c34e4ba 100644 --- a/python/tests/beagle.py +++ b/python/tests/beagle.py @@ -41,8 +41,6 @@ this implementation computes Equation 1 of BB2016. The functions used in an attempt to faithfully implement the BEAGLE algorithm are kept for documentation. """ -import logging - import numpy as np import _tskit @@ -526,19 +524,16 @@ def get_map_alleles(allele_probs): Assuming all biallelic sites, the output is an array of size x, where x is the number of imputed markers. + WARN: If the allele probabilities are equal, then allele 0 is arbitrarily chosen. + :param numpy.ndarray allele_probs: Interpolated allele probabilities. :return: Imputed alleles in the query haplotype. :rtype: numpy.ndarray """ assert not np.any(allele_probs < 0), "Allele probabilities have negative values." assert not np.any(np.isnan(allele_probs)), "Allele probabilities have NaN values." - x = allele_probs.shape[0] - imputed_alleles = np.zeros(x, dtype=int) - # TODO: Vectorise over the imputed markers - for i in np.arange(x): - if allele_probs[i, 0] == allele_probs[i, 1]: - logging.warning(f"Allele probabilities at imputed marker {i} are equal.") - imputed_alleles[i] = np.argmax(allele_probs[i, :]) + imputed_alleles = np.argmax(allele_probs, axis=1) + assert len(imputed_alleles) == allele_probs.shape[0] return imputed_alleles diff --git a/python/tests/beagle_numba.py b/python/tests/beagle_numba.py index 4fccaea87c..e292a77215 100644 --- a/python/tests/beagle_numba.py +++ b/python/tests/beagle_numba.py @@ -3,8 +3,6 @@ This is the numba-fied version of `beagle.py`. """ -import logging - import numpy as np from numba import njit @@ -324,17 +322,14 @@ def get_map_alleles(allele_probs): Assuming all biallelic sites, the output is an array of size x, where x is the number of imputed markers. + WARN: If the allele probabilities are equal, then allele 0 is arbitrarily chosen. + :param numpy.ndarray allele_probs: Interpolated allele probabilities. :return: Imputed alleles in the query haplotype. :rtype: numpy.ndarray """ - x = allele_probs.shape[0] - imputed_alleles = np.zeros(x, dtype=int) - # TODO: Vectorise over the imputed markers - for i in np.arange(x): - if allele_probs[i, 0] == allele_probs[i, 1]: - logging.warning(f"Allele probabilities at imputed marker {i} are equal.") - imputed_alleles[i] = np.argmax(allele_probs[i, :]) + imputed_alleles = np.argmax(allele_probs, axis=1) + assert len(imputed_alleles) == allele_probs.shape[0] return imputed_alleles