diff --git a/lshmm/api.py b/lshmm/api.py index 5b88181..9e1fd9d 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -24,21 +24,38 @@ REF_HET_OBS_HOM = 2 MISSING_INDEX = 3 +MISSING = -1 + def check_alleles(alleles, m): """ - Checks the specified allele list and returns a list of lists - of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used - for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length - num_sites. + Check a list of allele lists (or strings representing alleles) at m sites, and + return a list of counts of distinct alleles at the m sites. + + If alleles is a list of strings, then each string represents distinct alleles + at a site, and each character in a string represents a distinct allele. + It is assumed that MISSING is not encoded in these strings. + + Note MISSING values in allele lists are excluded from the counts. + + :param list alleles: A list of lists of alleles (or strings). + :param int m: Number of sites. + :return: An array of counts of distinct alleles at each site. + :rtype: numpy.ndarray """ + num_sites = m + if len(alleles) != num_sites: + err_msg = "Number of allele lists (or strings) is not equal to number of sites." + raise ValueError(err_msg) + # Process string encoding of distinct alleles. if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) - if len(alleles) != m: - raise ValueError("Malformed alleles list") - n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + # Otherwise, process allele lists. + exclusion_set = np.array([MISSING]) + n_alleles = np.zeros(num_sites, dtype=np.int8) + for i in range(num_sites): + uniq_alleles = np.unique(alleles[i]) + n_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) return n_alleles @@ -132,12 +149,11 @@ def set_emission_probabilities( # Check alleles should go in here, and modify e before passing to the algorithm # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(reference_panel[j, :], query[:, j]))) - for j in range(reference_panel.shape[0]) - ] - ) + exclusion_set = np.array([MISSING]) + n_alleles = np.zeros(m, dtype=np.int8) + for j in range(reference_panel.shape[0]): + uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) + n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set)) else: n_alleles = check_alleles(alleles, m) diff --git a/tests/test_API.py b/tests/test_API.py index 77b8700..6c0678c 100644 --- a/tests/test_API.py +++ b/tests/test_API.py @@ -86,15 +86,24 @@ def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): n = H.shape[1] m = ts.get_num_sites() + def _get_num_alleles(ref_haps, query): + assert ref_haps.shape[0] == query.shape[1] + num_sites = ref_haps.shape[0] + num_alleles = np.zeros(num_sites, dtype=np.int8) + exclusion_set = np.array([MISSING]) + for i in range(num_sites): + uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) + num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) + assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." + return num_alleles + # Here we have equal mutation and recombination r = np.zeros(m) + 0.01 mu = np.zeros(m) + 0.01 r[0] = 0 for s in haplotypes: - n_alleles = np.int8( - [len(np.unique(np.append(H[j, :], s[:, j]))) for j in range(m)] - ) + n_alleles = _get_num_alleles(H, s) e = self.haplotype_emission( mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation ) @@ -106,9 +115,7 @@ def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): for s, r, mu in itertools.product(haplotypes, rs, mus): r[0] = 0 - n_alleles = np.int8( - [len(np.unique(np.append(H[j, :], s[:, j]))) for j in range(H.shape[0])] - ) + n_alleles = _get_num_alleles(H, s) e = self.haplotype_emission( mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation ) diff --git a/tests/test_API_multiallelic.py b/tests/test_API_multiallelic.py index 3fce478..338f117 100644 --- a/tests/test_API_multiallelic.py +++ b/tests/test_API_multiallelic.py @@ -74,6 +74,17 @@ def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): n = H.shape[1] m = ts.get_num_sites() + def _get_num_alleles(ref_haps, query): + assert ref_haps.shape[0] == query.shape[1] + num_sites = ref_haps.shape[0] + num_alleles = np.zeros(num_sites, dtype=np.int8) + exclusion_set = np.array([MISSING]) + for i in range(num_sites): + uniq_alleles = np.unique(np.append(ref_haps[i, :], query[:, i])) + num_alleles[i] = np.sum(~np.isin(uniq_alleles, exclusion_set)) + assert np.all(num_alleles >= 0), "Number of alleles cannot be zero." + return num_alleles + # Here we have equal mutation and recombination r = np.zeros(m) + 0.01 mu = np.zeros(m) + 0.01 @@ -82,9 +93,7 @@ def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): for s in haplotypes: # Must be calculated from the genotype matrix because we can now get back mutations that # result in the number of alleles being higher than the number of alleles in the reference panel. - n_alleles = np.int8( - [len(np.unique(np.append(H[j, :], s[:, j]))) for j in range(m)] - ) + n_alleles = _get_num_alleles(H, s) e = self.haplotype_emission( mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation ) @@ -100,9 +109,7 @@ def example_parameters_haplotypes(self, ts, seed=42, scale_mutation=True): for s, r, mu in itertools.product(haplotypes, rs, mus): r[0] = 0 - n_alleles = np.int8( - [len(np.unique(np.append(H[j, :], s[:, j]))) for j in range(H.shape[0])] - ) + n_alleles = _get_num_alleles(H, s) e = self.haplotype_emission( mu, m, n_alleles, scale_mutation_based_on_n_alleles=scale_mutation )