diff --git a/lshmm/api.py b/lshmm/api.py index 5b88181..69ff87d 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -17,6 +17,9 @@ path_ll_hap, ) +MISSING = -1 +NONCOPY = -2 + EQUAL_BOTH_HOM = 4 UNEQUAL_BOTH_HOM = 0 BOTH_HET = 7 @@ -27,18 +30,27 @@ def check_alleles(alleles, m): """ - Checks the specified allele list and returns a list of lists - of alleles of length num_sites. - If alleles is a 1D list of strings, assume that this list is used - for each site and return num_sites copies of this list. - Otherwise, raise a ValueError if alleles is not a list of length - num_sites. + Checks the specified allele list and returns a list of allele lists of length m. + + If alleles is a 1D list of strings, assume that this list is used for each site + and return num_sites copies of this list. Otherwise, raise a ValueError + if alleles is not a list of length m. + + Note MISSING and NONCOPY values are excluded from the counts. + + :param list alleles: A list of lists of alleles or strings. + :param int m: Number of sites. + :return: An array of number of distinct alleles at each site. + :rtype: numpy.ndarray """ if isinstance(alleles[0], str): return np.int8([len(alleles) for _ in range(m)]) if len(alleles) != m: - raise ValueError("Malformed alleles list") - n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles]) + raise ValueError("Number of alleles list is not equal to number of sites.") + exclusion_set = np.array([MISSING, NONCOPY]) + n_alleles = np.zeros(m, dtype=np.int8) + for i in range(m): + n_alleles[i] = np.sum(~np.isin(np.unique(alleles[i]), exclusion_set)) return n_alleles @@ -132,12 +144,11 @@ def set_emission_probabilities( # Check alleles should go in here, and modify e before passing to the algorithm # If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel. if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(reference_panel[j, :], query[:, j]))) - for j in range(reference_panel.shape[0]) - ] - ) + exclusion_set = np.array([MISSING, NONCOPY]) + n_alleles = np.zeros(m, dtype=np.int8) + for j in range(reference_panel.shape[0]): + uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j])) + n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set)) else: n_alleles = check_alleles(alleles, m) diff --git a/tests/test_API_noncopy_2.py b/tests/test_API_noncopy_2.py index f54f24f..da651a8 100644 --- a/tests/test_API_noncopy_2.py +++ b/tests/test_API_noncopy_2.py @@ -38,7 +38,7 @@ def get_ancestral_haplotypes(self, ts): tsp = tables.tree_sequence() B = tsp.genotype_matrix().T - # Modified. Originally, this was filled with -2 by default. + # Modified. Originally, this was filled with NONCOPY by default. A = np.full((ts.num_nodes, ts.num_sites), NONCOPY, dtype=np.int8) for edge in ts.edges(): start = bisect.bisect_left(sites, edge.left) @@ -55,22 +55,19 @@ def get_ancestral_haplotypes(self, ts): def example_haplotypes(self, ts, num_random=10, seed=42): H = self.get_ancestral_haplotypes(ts) - #H = ts.genotype_matrix() s = H[:, 0].reshape(1, H.shape[0]) H = H[:, 1:] - # TODO: Figure out why tests fail when MISSING is in a query. haplotypes = [s, H[:, -1].reshape(1, H.shape[0])] s_tmp = s.copy() - #s_tmp[0, -1] = MISSING # End - s_tmp[0, 0] = MISSING # Beginning - #haplotypes.append(s_tmp) + s_tmp[0, -1] = MISSING # End + haplotypes.append(s_tmp) s_tmp = s.copy() s_tmp[0, ts.num_sites // 2] = MISSING - #haplotypes.append(s_tmp) + haplotypes.append(s_tmp) s_tmp = s.copy() s_tmp[0, :] = MISSING - #haplotypes.append(s_tmp) + haplotypes.append(s_tmp) return H, haplotypes @@ -116,7 +113,6 @@ def _get_n_states(H, s): exclude_set = np.array([MISSING, NONCOPY]) for j in range(m): proper_set = np.unique(np.append(H[j, :], s[:, j])) - #n_states[j] = len(proper_set) n_states[j] = np.sum(~np.isin(proper_set, exclude_set)) assert np.all(n_states >= 0) return n_states