From 7c1c600f576fccd5bd6c64305bd6af9ea54808f8 Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 11:10:53 +0100 Subject: [PATCH] Update tests for haploid case --- python/tests/test_haplotype_matching.py | 27 +++++++++---------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 17e12bcf27..813daf1cb2 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -927,11 +927,10 @@ def verify(self, ts): # Ensure that the decoded matrices are the same flipped_H = np.flip(H, axis=0) flipped_s = np.flip(s, axis=1) - num_alleles = ls.core.get_num_alleles(flipped_H, flipped_s) F_mirror_matrix, c, ll = ls.forwards( reference_panel=flipped_H, query=flipped_s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=r_flip, prob_mutation=np.flip(mu), scale_mutation_rate=False, @@ -953,11 +952,10 @@ def verify(self, ts): # Warning from lshmm: # Passed a vector of mutation rates, but rescaling each mutation # rate conditional on the number of alleles - num_alleles = ls.core.get_num_alleles(H, s) F, c, ll = ls.forwards( reference_panel=H, query=s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation, @@ -983,11 +981,10 @@ class TestForwardBackwardTree(FBAlgorithmBase): def verify(self, ts): for n, H, s, r, mu in self.example_parameters_haplotypes(ts): - num_alleles = ls.core.get_num_alleles(H, s) F, c, ll = ls.forwards( reference_panel=H, query=s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=False, @@ -995,7 +992,7 @@ def verify(self, ts): B = ls.backwards( reference_panel=H, query=s, - num_alleles=num_alleles, + ploidy=1, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, @@ -1030,11 +1027,10 @@ class TestTreeViterbiHap(VitAlgorithmBase): def verify(self, ts): for n, H, s, r, mu in self.example_parameters_haplotypes(ts): - num_alleles = ls.core.get_num_alleles(H, s) path, ll = ls.viterbi( reference_panel=H, query=s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=False, @@ -1050,7 +1046,7 @@ def verify(self, ts): ll_check = ls.path_loglik( reference_panel=H, query=s, - num_alleles=num_alleles, + ploidy=1, path=path_tree, prob_recombination=r, prob_mutation=mu, @@ -1072,12 +1068,11 @@ def check_viterbi(ts, h, recombination=None, mutation=None): G = ts.genotype_matrix() s = h.reshape(1, m) - num_alleles = ls.core.get_num_alleles(G, s) path, ll = ls.viterbi( reference_panel=G, query=s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=recombination, prob_mutation=mutation, scale_mutation_rate=False, @@ -1095,7 +1090,7 @@ def check_viterbi(ts, h, recombination=None, mutation=None): ll_check = ls.path_loglik( reference_panel=G, query=s, - num_alleles=num_alleles, + ploidy=1, path=path_tree, prob_recombination=recombination, prob_mutation=mutation, @@ -1131,12 +1126,11 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): G = ts.genotype_matrix() s = h.reshape(1, m) - num_alleles = ls.core.get_num_alleles(G, s) F, c, ll = ls.forwards( reference_panel=G, query=s, - num_alleles=num_alleles, + ploidy=1, prob_recombination=recombination, prob_mutation=mutation, scale_mutation_rate=False, @@ -1179,12 +1173,11 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): G = ts.genotype_matrix() s = h.reshape(1, m) - num_alleles = ls.core.get_num_alleles(G, s) B = ls.backwards( reference_panel=G, query=h.reshape(1, m), - num_alleles=num_alleles, + ploidy=1, normalisation_factor_from_forward=forward_cm.normalisation_factor, prob_recombination=recombination, prob_mutation=mutation,