From 0813eaf5d4ac38ea71cdf9d2cecc75f757f6815b Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 18:29:11 +0100 Subject: [PATCH] Update tests for genotype matching --- python/tests/test_genotype_matching.py | 46 ++++++++++++------------- python/tests/test_haplotype_matching.py | 3 ++ 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/python/tests/test_genotype_matching.py b/python/tests/test_genotype_matching.py index ce35b90b18..9829b17d89 100644 --- a/python/tests/test_genotype_matching.py +++ b/python/tests/test_genotype_matching.py @@ -1326,10 +1326,11 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] + H_check[i, :], H_check[i, :] ) cm_d = ls_forward_tree(s[0, :], ts_check, r, mu) @@ -1345,14 +1346,13 @@ def verify(self, ts): self.assertAllClose(ll_tree, ll_mirror_tree_dict) # Ensure that the decoded matrices are the same - flipped_G_check = np.flip(G_check, axis=0) + flipped_H_check = np.flip(H_check, axis=0) flipped_s = np.flip(s, axis=1) - num_alleles = ls.core.get_num_alleles(flipped_G_check, flipped_s) F_mirror_matrix, c, ll = ls.forwards( - flipped_G_check, + flipped_H_check, flipped_s, - num_alleles=num_alleles, + ploidy=2, prob_recombination=r_flip, prob_mutation=np.flip(mu), scale_mutation_rate=False, @@ -1372,17 +1372,17 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] + H_check[i, :], H_check[i, :] ) - num_alleles = ls.core.get_num_alleles(G_check, s) F, c, ll = ls.forwards( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=False, @@ -1404,25 +1404,25 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] + H_check[i, :], H_check[i, :] ) - num_alleles = ls.core.get_num_alleles(G_check, s) F, c, ll = ls.forwards( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=False, ) B = ls.backwards( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, @@ -1465,26 +1465,26 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] + H_check[i, :], H_check[i, :] ) ts_check = ts.simplify(range(1, n + 1), filter_sites=False) - num_alleles = ls.core.get_num_alleles(G_check, s) phased_path, ll = ls.viterbi( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=False, ) path_ll_matrix = ls.path_loglik( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, path=phased_path, prob_recombination=r, prob_mutation=mu, @@ -1498,9 +1498,9 @@ def verify(self, ts): path_tree_dict = c_v.traceback() # Work out the likelihood of the proposed path path_ll_tree = ls.path_loglik( - reference_panel=G_check, + reference_panel=H_check, query=s, - num_alleles=num_alleles, + ploidy=2, path=np.transpose(path_tree_dict), prob_recombination=r, prob_mutation=mu, diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 813daf1cb2..0b00e8cb00 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1062,6 +1062,7 @@ def check_viterbi(ts, h, recombination=None, mutation=None): assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 + recombination[0] = 0.0 if mutation is None: mutation = np.zeros(ts.num_sites) precision = 22 @@ -1121,6 +1122,7 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 + recombination[0] = 0.0 if mutation is None: mutation = np.zeros(ts.num_sites) @@ -1168,6 +1170,7 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 + recombination[0] = 0.0 if mutation is None: mutation = np.zeros(ts.num_sites)