Skip to content

Commit

Permalink
Merge pull request #110 from szhan/split_tests_by_algo
Browse files Browse the repository at this point in the history
Split tests by algorithm
  • Loading branch information
szhan authored Jun 18, 2024
2 parents 309aa18 + 6fb8bab commit 7bb3108
Show file tree
Hide file tree
Showing 10 changed files with 501 additions and 470 deletions.
99 changes: 0 additions & 99 deletions tests/test_api_diploid.py → tests/test_api_fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import lshmm as ls
import lshmm.core as core
import lshmm.fb_diploid as fbd
import lshmm.vit_diploid as vd


class TestForwardBackwardDiploid(lsbase.ForwardBackwardAlgorithmBase):
Expand Down Expand Up @@ -122,101 +121,3 @@ def test_ts_larger(self, scale_mutation_rate, include_ancestors):
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)


class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=2,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(
n=n,
m=m,
G=G_vs,
s=s,
e=e_vs,
r=r,
)
path_vs = vd.backwards_viterbi_dip(m=m, V_last=V_vs, P=P_vs)
phased_path_vs = vd.get_phased_path(n=n, path=path_vs)
path, ll = ls.viterbi(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)

self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
94 changes: 0 additions & 94 deletions tests/test_api_haploid.py → tests/test_api_fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import lshmm as ls
import lshmm.core as core
import lshmm.fb_haploid as fbh
import lshmm.vit_haploid as vh


class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
Expand Down Expand Up @@ -117,96 +116,3 @@ def test_ts_larger(self, scale_mutation_rate, include_ancestors):
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=1,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
num_alleles = core.get_num_alleles(H_vs, s)
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H_vs,
s=s,
e=e_vs,
r=r,
)
path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs)
path, ll = ls.viterbi(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)
self.assertAllClose(ll_vs, ll)
self.assertAllClose(path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=46, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import lshmm as ls
import lshmm.core as core
import lshmm.fb_haploid as fbh
import lshmm.vit_haploid as vh


class TestForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
Expand Down Expand Up @@ -81,57 +80,3 @@ def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors):
def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n16()
self.verify(ts, scale_mutation_rate, include_ancestors)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=1,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
num_alleles = core.get_num_alleles(H_vs, s)
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
n, m, H_vs, s, e_vs, r
)
path_vs = vh.backwards_viterbi_hap(m, V_vs, P_vs)
path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
path, ll = ls.viterbi(
reference_panel=H_vs,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)
self.assertAllClose(ll_vs, ll)
self.assertAllClose(ll_vs, path_ll_hap)
self.assertAllClose(path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n10_no_recomb(
self, scale_mutation_rate, include_ancestors
):
ts = self.get_ts_multiallelic_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n6()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n8()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_multiallelic_n16()
self.verify(ts, scale_mutation_rate, include_ancestors)
104 changes: 104 additions & 0 deletions tests/test_api_vit_diploid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest

from . import lsbase
import lshmm as ls
import lshmm.core as core
import lshmm.vit_diploid as vd


class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, scale_mutation_rate, include_ancestors):
for n, m, H_vs, query, e_vs, r, mu in self.get_examples_pars(
ts,
ploidy=2,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)
num_alleles = core.get_num_alleles(H_vs, query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(
n=n,
m=m,
G=G_vs,
s=s,
e=e_vs,
r=r,
)
path_vs = vd.backwards_viterbi_dip(m=m, V_last=V_vs, P=P_vs)
phased_path_vs = vd.get_phased_path(n=n, path=path_vs)
path, ll = ls.viterbi(
reference_panel=G_vs,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation_rate,
)

self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n6(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n6()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n8_high_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n8_high_recomb()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple_n16(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n16()
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)

@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_larger(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_custom_pars(
ref_panel_size=45, length=1e5, mean_r=1e-5, mean_mu=1e-5
)
self.verify(
ts,
scale_mutation_rate=scale_mutation_rate,
include_ancestors=include_ancestors,
)
Loading

0 comments on commit 7bb3108

Please sign in to comment.