From ba582b341f2150e110a122349481f1266593e7f3 Mon Sep 17 00:00:00 2001 From: szhan Date: Fri, 21 Jun 2024 10:52:25 +0100 Subject: [PATCH] Update tests --- lshmm/api.py | 3 ++- tests/test_api_fb_diploid.py | 16 ++++++++-------- tests/test_api_fb_haploid.py | 7 ++++--- tests/test_api_fb_haploid_multi.py | 8 ++++---- tests/test_api_vit_diploid.py | 10 +++++----- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lshmm/api.py b/lshmm/api.py index 5f88078..ef76468 100644 --- a/lshmm/api.py +++ b/lshmm/api.py @@ -65,7 +65,6 @@ def check_inputs( # Check the reference panel. if not len(reference_panel.shape) == 2: - num_sites, num_ref_haps = reference_panel.shape err_msg = "Reference panel array has incorrect dimensions." raise ValueError(err_msg) @@ -79,6 +78,8 @@ def check_inputs( err_msg += "Only 0/1 encoding is supported in diploid mode." raise ValueError(err_msg) + num_sites, num_ref_haps = reference_panel.shape + # Check the queries. if query.shape[0] != ploidy: err_msg = "Query array has incorrect dimensions." diff --git a/tests/test_api_fb_diploid.py b/tests/test_api_fb_diploid.py index a5b485f..a092beb 100644 --- a/tests/test_api_fb_diploid.py +++ b/tests/test_api_fb_diploid.py @@ -8,16 +8,16 @@ class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 2 for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=2, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) - num_alleles = core.get_num_alleles(H_vs, query) F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop( n=n, @@ -38,18 +38,18 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): r=r, ) F, c, ll = ls.forwards( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, normalise=True, ) B = ls.backwards( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py index 3f5220c..058c575 100644 --- a/tests/test_api_fb_haploid.py +++ b/tests/test_api_fb_haploid.py @@ -8,9 +8,10 @@ class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, @@ -36,7 +37,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): F, c, ll = ls.forwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, @@ -45,7 +46,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): B = ls.backwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py index a50b299..86ad692 100644 --- a/tests/test_api_fb_haploid_multi.py +++ b/tests/test_api_fb_haploid_multi.py @@ -8,14 +8,14 @@ class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 1 for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=1, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): - num_alleles = core.get_num_alleles(H_vs, s) F_vs, c_vs, ll_vs = fbh.forwards_ls_hap( n=n, m=m, @@ -36,7 +36,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): F, c, ll = ls.forwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate, @@ -45,7 +45,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): B = ls.backwards( reference_panel=H_vs, query=s, - num_alleles=num_alleles, + ploidy=ploidy, normalisation_factor_from_forward=c, prob_recombination=r, prob_mutation=mu, diff --git a/tests/test_api_vit_diploid.py b/tests/test_api_vit_diploid.py index 48d67b7..ea6a47c 100644 --- a/tests/test_api_vit_diploid.py +++ b/tests/test_api_vit_diploid.py @@ -8,16 +8,16 @@ class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase): def verify(self, ts, scale_mutation_rate, include_ancestors): + ploidy = 2 for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars( ts, - ploidy=2, + ploidy=ploidy, scale_mutation_rate=scale_mutation_rate, include_ancestors=include_ancestors, include_extreme_rates=True, ): G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs) s = core.convert_haplotypes_to_unphased_genotypes(query) - num_alleles = core.get_num_alleles(H_vs, query) V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem( n=n, @@ -30,9 +30,9 @@ def verify(self, ts, scale_mutation_rate, include_ancestors): path_vs = vd.backwards_viterbi_dip(m=m, V_last=V_vs, P=P_vs) phased_path_vs = vd.get_phased_path(n=n, path=path_vs) path, ll = ls.viterbi( - reference_panel=G_vs, - query=s, - num_alleles=num_alleles, + reference_panel=H_vs, + query=query, + ploidy=ploidy, prob_recombination=r, prob_mutation=mu, scale_mutation_rate=scale_mutation_rate,