Skip to content

Commit

Permalink
Refactor and reorganise
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Apr 20, 2024
1 parent 5789876 commit 4f2b92b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
56 changes: 28 additions & 28 deletions tests/test_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,7 @@
import lshmm.vit_haploid as vh


class TestMethodsHaploid(lsbase.FBAlgorithmBase):
def test_simple_n10_no_recombination(self):
ts = self.get_simple_n10_no_recombination()
self.verify(ts)

def test_simple_n6(self):
ts = self.get_simple_n6()
self.verify(ts)

def test_simple_n8(self):
ts = self.get_simple_n8()
self.verify(ts)

def test_simple_n8_high_recombination(self):
ts = self.get_simple_n8_high_recombination()
self.verify(ts)

def test_simple_n16(self):
ts = self.get_simple_n16()
self.verify(ts)

class TestForwardBackwardHaploid(lsbase.FBAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
Expand All @@ -40,8 +20,6 @@ def verify(self, ts):
F, c, ll = ls.forwards(H_vs, s, r, mu)
B = ls.backwards(H_vs, s, c, r, mu)


class TestMethodsDiploid(lsbase.FBAlgorithmBase):
def test_simple_n10_no_recombination(self):
ts = self.get_simple_n10_no_recombination()
self.verify(ts)
Expand All @@ -62,6 +40,8 @@ def test_simple_n16(self):
ts = self.get_simple_n16()
self.verify(ts)


class TestForwardBackwardDiploid(lsbase.FBAlgorithmBase):
def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts):
F_vs, c_vs, ll_vs = fbd.forward_ls_dip_loop(
Expand All @@ -74,8 +54,6 @@ def verify(self, ts):
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def test_simple_n10_no_recombination(self):
ts = self.get_simple_n10_no_recombination()
self.verify(ts)
Expand All @@ -96,6 +74,8 @@ def test_simple_n16(self):
ts = self.get_simple_n16()
self.verify(ts)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
Expand All @@ -106,8 +86,6 @@ def verify(self, ts):
self.assertAllClose(ll_vs, ll)
self.assertAllClose(path_vs, path)


class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def test_simple_n10_no_recombination(self):
ts = self.get_simple_n10_no_recombination()
self.verify(ts)
Expand All @@ -124,10 +102,12 @@ def test_simple_n8_high_recombination(self):
ts = self.get_simple_n8_high_recombination()
self.verify(ts)

def test_simple_n_16(self):
def test_simple_n16(self):
ts = self.get_simple_n16()
self.verify(ts)


class TestViterbiDiploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, G_vs, s, e_vs, r, mu in self.get_examples_pars_diploid(ts):
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r)
Expand All @@ -136,3 +116,23 @@ def verify(self, ts):
path, ll = ls.viterbi(G_vs, s, r, p_mutation=mu)
self.assertAllClose(ll_vs, ll)
self.assertAllClose(phased_path_vs, path)

def test_simple_n10_no_recombination(self):
ts = self.get_simple_n10_no_recombination()
self.verify(ts)

def test_simple_n6(self):
ts = self.get_simple_n6()
self.verify(ts)

def test_simple_n8(self):
ts = self.get_simple_n8()
self.verify(ts)

def test_simple_n8_high_recombination(self):
ts = self.get_simple_n8_high_recombination()
self.verify(ts)

def test_simple_n_16(self):
ts = self.get_simple_n16()
self.verify(ts)
38 changes: 18 additions & 20 deletions tests/test_API_multiallelic.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
from . import lsbase
import lshmm as ls
import lshmm.fb_diploid as fbd
import lshmm.fb_haploid as fbh
import lshmm.vit_diploid as vd
import lshmm.vit_haploid as vh


class TestMethodsHaploid(lsbase.FBAlgorithmBase):
def test_multiallelic_n10_no_recombination(self):
ts = self.get_multiallelic_n10_no_recombination()
self.verify(ts)

def test_multiallelic_n6(self):
ts = self.get_multiallelic_n6()
self.verify(ts)

def test_multiallelic_n8(self):
ts = self.get_multiallelic_n8()
self.verify(ts)

def test_multiallelic_n16(self):
ts = self.get_multiallelic_n16()
self.verify(ts)

def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r)
Expand All @@ -48,8 +30,6 @@ def verify(self, ts):
self.assertAllClose(B, B_vs)
self.assertAllClose(ll_vs, ll)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def test_multiallelic_n10_no_recombination(self):
ts = self.get_multiallelic_n10_no_recombination()
self.verify(ts)
Expand All @@ -66,6 +46,8 @@ def test_multiallelic_n16(self):
ts = self.get_multiallelic_n16()
self.verify(ts)


class TestViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts):
for n, m, H_vs, s, e_vs, r, mu in self.get_examples_pars_haploid(ts):
V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
Expand All @@ -77,3 +59,19 @@ def verify(self, ts):
self.assertAllClose(ll_vs, ll)
self.assertAllClose(ll_vs, path_ll_hap)
self.assertAllClose(path_vs, path)

def test_multiallelic_n10_no_recombination(self):
ts = self.get_multiallelic_n10_no_recombination()
self.verify(ts)

def test_multiallelic_n6(self):
ts = self.get_multiallelic_n6()
self.verify(ts)

def test_multiallelic_n8(self):
ts = self.get_multiallelic_n8()
self.verify(ts)

def test_multiallelic_n16(self):
ts = self.get_multiallelic_n16()
self.verify(ts)
4 changes: 2 additions & 2 deletions tests/test_LS_haploid_diploid.py → tests/test_non_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import lshmm.vit_haploid as vh


class TestNonTreeMethodsHaploid(lsbase.FBAlgorithmBase):
class TestNonTreeForwardBackwardHaploid(lsbase.FBAlgorithmBase):
def verify(self, n, m, H_vs, s, e_vs, r):
F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False)
B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_larger(self):
self.verify(n, m, H_vs, s, e_vs, r)


class TestNonTreeMethodsDiploid(lsbase.FBAlgorithmBase):
class TestNonTreeForwardBackwardDiploid(lsbase.FBAlgorithmBase):
def verify(self, n, m, G_vs, s, e_vs, r):
F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True)
B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r)
Expand Down

0 comments on commit 4f2b92b

Please sign in to comment.