Skip to content

Commit

Permalink
Implement excluding NONCOPY and MISSING and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 4, 2024
1 parent 1b3ed70 commit 4f2864f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
39 changes: 25 additions & 14 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
path_ll_hap,
)

# Sentinel allele codes. check_alleles leaves both out of its per-site
# distinct-allele counts.
MISSING = -1
NONCOPY = -2

EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
Expand All @@ -27,18 +30,27 @@

def check_alleles(alleles, m):
    """
    Count the distinct alleles at each of the m sites.

    If alleles is a 1D list of strings, that single list is assumed to apply
    to every site and its length is reported m times. Otherwise alleles must
    contain exactly m per-site allele lists, and the count for a site is the
    number of unique values in its list, with MISSING and NONCOPY left out.

    :param list alleles: A list of strings, or a list of m lists of alleles.
    :param int m: Number of sites.
    :return: An int8 array of length m holding the number of distinct alleles
        at each site.
    :rtype: numpy.ndarray
    :raises ValueError: If alleles is not a list of strings and its length
        differs from m.
    """
    # Guard on the length so that an empty list falls through to the length
    # check below instead of raising IndexError on alleles[0].
    if len(alleles) > 0 and isinstance(alleles[0], str):
        # Shared allele list: its raw length is used for every site
        # (no de-duplication and no exclusion is applied on this path).
        return np.int8([len(alleles)] * m)
    if len(alleles) != m:
        raise ValueError("Number of alleles list is not equal to number of sites.")
    exclusion_set = np.array([MISSING, NONCOPY])
    n_alleles = np.zeros(m, dtype=np.int8)
    for i, alleles_site in enumerate(alleles):
        # Unique values first, then drop the sentinel codes before counting.
        n_alleles[i] = np.sum(~np.isin(np.unique(alleles_site), exclusion_set))
    return n_alleles


Expand Down Expand Up @@ -132,12 +144,11 @@ def set_emission_probabilities(
# Check alleles should go in here, and modify e before passing to the algorithm
# If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel.
if alleles is None:
n_alleles = np.int8(
[
len(np.unique(np.append(reference_panel[j, :], query[:, j])))
for j in range(reference_panel.shape[0])
]
)
exclusion_set = np.array([MISSING, NONCOPY])
n_alleles = np.zeros(m, dtype=np.int8)
for j in range(reference_panel.shape[0]):
uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j]))
n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set))
else:
n_alleles = check_alleles(alleles, m)

Expand Down
14 changes: 5 additions & 9 deletions tests/test_API_noncopy_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_ancestral_haplotypes(self, ts):
tsp = tables.tree_sequence()
B = tsp.genotype_matrix().T

# Modified. Originally, this was filled with -2 by default.
# Modified. Originally, this was filled with NONCOPY by default.
A = np.full((ts.num_nodes, ts.num_sites), NONCOPY, dtype=np.int8)
for edge in ts.edges():
start = bisect.bisect_left(sites, edge.left)
Expand All @@ -55,22 +55,19 @@ def get_ancestral_haplotypes(self, ts):

def example_haplotypes(self, ts, num_random=10, seed=42):
    """
    Build a reference panel and a list of query haplotypes from ts.

    The matrix returned by get_ancestral_haplotypes is split in two: its
    first column, reshaped to (1, H.shape[0]), is the base query s, and the
    remaining columns form the panel H. The queries returned are s itself,
    the last column of the panel, and three copies of s with MISSING written
    at the last site, at site ts.num_sites // 2, and at every site.

    num_random and seed are accepted but not used by this method.

    :return: The panel H and the list of query haplotypes.
    """
    H = self.get_ancestral_haplotypes(ts)
    s = H[:, 0].reshape(1, H.shape[0])
    H = H[:, 1:]

    haplotypes = [s, H[:, -1].reshape(1, H.shape[0])]
    # One masked copy of s per pattern: end, middle, then the whole query.
    for masked in (-1, ts.num_sites // 2, slice(None)):
        s_tmp = s.copy()
        s_tmp[0, masked] = MISSING
        haplotypes.append(s_tmp)

    return H, haplotypes

Expand Down Expand Up @@ -116,7 +113,6 @@ def _get_n_states(H, s):
exclude_set = np.array([MISSING, NONCOPY])
for j in range(m):
proper_set = np.unique(np.append(H[j, :], s[:, j]))
#n_states[j] = len(proper_set)
n_states[j] = np.sum(~np.isin(proper_set, exclude_set))
assert np.all(n_states >= 0)
return n_states
Expand Down

0 comments on commit 4f2864f

Please sign in to comment.