Skip to content

Commit

Permalink
Update tests for genotype matching
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent 7c1c600 commit 0813eaf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
46 changes: 23 additions & 23 deletions python/tests/test_genotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,10 +1326,11 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
H_check[i, :], H_check[i, :]
)

cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1345,14 +1346,13 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree_dict)

# Ensure that the decoded matrices are the same
flipped_G_check = np.flip(G_check, axis=0)
flipped_H_check = np.flip(H_check, axis=0)
flipped_s = np.flip(s, axis=1)
num_alleles = ls.core.get_num_alleles(flipped_G_check, flipped_s)

F_mirror_matrix, c, ll = ls.forwards(
flipped_G_check,
flipped_H_check,
flipped_s,
num_alleles=num_alleles,
ploidy=2,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
Expand All @@ -1372,17 +1372,17 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
H_check[i, :], H_check[i, :]
)

num_alleles = ls.core.get_num_alleles(G_check, s)
F, c, ll = ls.forwards(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
Expand All @@ -1404,25 +1404,25 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
H_check[i, :], H_check[i, :]
)

num_alleles = ls.core.get_num_alleles(G_check, s)
F, c, ll = ls.forwards(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down Expand Up @@ -1465,26 +1465,26 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
H_check[i, :], H_check[i, :]
)
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)

num_alleles = ls.core.get_num_alleles(G_check, s)
phased_path, ll = ls.viterbi(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
path_ll_matrix = ls.path_loglik(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
path=phased_path,
prob_recombination=r,
prob_mutation=mu,
Expand All @@ -1498,9 +1498,9 @@ def verify(self, ts):
path_tree_dict = c_v.traceback()
# Work out the likelihood of the proposed path
path_ll_tree = ls.path_loglik(
reference_panel=G_check,
reference_panel=H_check,
query=s,
num_alleles=num_alleles,
ploidy=2,
path=np.transpose(path_tree_dict),
prob_recombination=r,
prob_mutation=mu,
Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,7 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
recombination[0] = 0.0
if mutation is None:
mutation = np.zeros(ts.num_sites)
precision = 22
Expand Down Expand Up @@ -1121,6 +1122,7 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
recombination[0] = 0.0
if mutation is None:
mutation = np.zeros(ts.num_sites)

Expand Down Expand Up @@ -1168,6 +1170,7 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
assert len(h) == m
if recombination is None:
recombination = np.zeros(ts.num_sites) + 1e-9
recombination[0] = 0.0
if mutation is None:
mutation = np.zeros(ts.num_sites)

Expand Down

0 comments on commit 0813eaf

Please sign in to comment.