Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent 24e1c5d commit ba582b3
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 21 deletions.
3 changes: 2 additions & 1 deletion lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def check_inputs(

# Check the reference panel.
if not len(reference_panel.shape) == 2:
num_sites, num_ref_haps = reference_panel.shape
err_msg = "Reference panel array has incorrect dimensions."
raise ValueError(err_msg)

Expand All @@ -79,6 +78,8 @@ def check_inputs(
err_msg += "Only 0/1 encoding is supported in diploid mode."
raise ValueError(err_msg)

num_sites, num_ref_haps = reference_panel.shape

# Check the queries.
if query.shape[0] != ploidy:
err_msg = "Query array has incorrect dimensions."
Expand Down
16 changes: 8 additions & 8 deletions tests/test_api_fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 2
for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=2,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
n=n,
Expand All @@ -38,18 +38,18 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
r=r,
)
F, c, ll = ls.forwards(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
reference_panel=H_vs,
query=query,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
normalise=True,
)
B = ls.backwards(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
reference_panel=H_vs,
query=query,
ploidy=ploidy,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=1,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
Expand All @@ -36,7 +37,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
F, c, ll = ls.forwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
Expand All @@ -45,7 +46,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
B = ls.backwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_api_fb_haploid_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 1
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=1,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
num_alleles = core.get_num_alleles(H_vs, s)
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
n=n,
m=m,
Expand All @@ -36,7 +36,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
F, c, ll = ls.forwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
Expand All @@ -45,7 +45,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
B = ls.backwards(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
ploidy=ploidy,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_api_vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
ploidy = 2
for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=2,
ploidy=ploidy,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(
n=n,
Expand All @@ -30,9 +30,9 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
path_vs = vd.backwards_viterbi_dip(m=m, V_last=V_vs, P=P_vs)
phased_path_vs = vd.get_phased_path(n=n, path=path_vs)
path, ll = ls.viterbi(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
reference_panel=H_vs,
query=query,
ploidy=ploidy,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
Expand Down

0 comments on commit ba582b3

Please sign in to comment.