Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 21, 2024
1 parent 0827282 commit 7c1bfde
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 29 deletions.
4 changes: 2 additions & 2 deletions tests/lsbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ def get_ts_simple_n8_high_recomb(self, seed=42):
msprime.sim_ancestry(
samples=8,
ploidy=1,
sequence_length=10,
sequence_length=20,
recombination_rate=20.0,
random_seed=seed,
),
rate=5.0,
rate=0.2,
model=msprime.BinaryMutationModel(),
discrete_genome=False,
random_seed=seed,
Expand Down
5 changes: 2 additions & 3 deletions tests/test_api_vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(
n=n,
m=m,
G=G_vs,
s=s,
s=query,
e=e_vs,
r=r,
)
Expand All @@ -47,7 +46,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16])
@pytest.mark.parametrize("num_samples", [8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
Expand Down
23 changes: 11 additions & 12 deletions tests/test_nontree_fb_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,44 @@ def verify(
include_extreme_rates=include_extreme_rates,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)

F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, s, e_vs, r, norm=True)
B_vs = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_vs, r)
F_vs, c_vs, ll_vs = fbd.forwards_ls_dip(n, m, G_vs, query, e_vs, r, norm=True)
B_vs = fbd.backwards_ls_dip(n, m, G_vs, query, e_vs, c_vs, r)
self.assertAllClose(np.sum(F_vs * B_vs, (1, 2)), np.ones(m))

F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop(
n, m, G_vs, s, e_vs, r, norm=True
n, m, G_vs, query, e_vs, r, norm=True
)
B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r)
B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, query, e_vs, c_tmp, r)
self.assertAllClose(np.sum(F_tmp * B_tmp, (1, 2)), np.ones(m))
self.assertAllClose(ll_vs, ll_tmp)

if not normalise:
F_tmp, c_tmp, ll_tmp = fbd.forwards_ls_dip(
n, m, G_vs, s, e_vs, r, norm=False
n, m, G_vs, query, e_vs, r, norm=False
)
if ll_tmp != -np.inf:
B_tmp = fbd.backwards_ls_dip(n, m, G_vs, s, e_vs, c_tmp, r)
B_tmp = fbd.backwards_ls_dip(n, m, G_vs, query, e_vs, c_tmp, r)
self.assertAllClose(
np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m)
)
self.assertAllClose(ll_vs, ll_tmp)

F_tmp, c_tmp, ll_tmp = fbd.forward_ls_dip_loop(
n, m, G_vs, s, e_vs, r, norm=False
n, m, G_vs, query, e_vs, r, norm=False
)
if ll_tmp != -np.inf:
B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, s, e_vs, c_tmp, r)
B_tmp = fbd.backward_ls_dip_loop(n, m, G_vs, query, e_vs, c_tmp, r)
self.assertAllClose(
np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m)
)
self.assertAllClose(ll_vs, ll_tmp)

F_tmp, ll_tmp = fbd.forward_ls_dip_starting_point(
n, m, G_vs, s, e_vs, r
n, m, G_vs, query, e_vs, r
)
if ll_tmp != -np.inf:
B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, s, e_vs, r)
B_tmp = fbd.backward_ls_dip_starting_point(n, m, G_vs, query, e_vs, r)
self.assertAllClose(
np.log10(np.sum(F_tmp * B_tmp, (1, 2))), ll_tmp * np.ones(m)
)
Expand All @@ -87,7 +86,7 @@ def test_ts_simple_n10_no_recomb(
include_extreme_rates=include_extreme_rates,
)

@pytest.mark.parametrize("num_samples", [4, 8, 16])
@pytest.mark.parametrize("num_samples", [8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
@pytest.mark.parametrize("normalise", [True, False])
Expand Down
23 changes: 11 additions & 12 deletions tests/test_nontree_vit_diploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
include_extreme_rates=True,
):
G_vs = core.convert_haplotypes_to_phased_genotypes(H_vs)
s = core.convert_haplotypes_to_unphased_genotypes(query)

V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, s, e_vs, r)
V_vs, P_vs, ll_vs = vd.forwards_viterbi_dip_low_mem(n, m, G_vs, query, e_vs, r)
path_vs = vd.backwards_viterbi_dip(m, V_vs, P_vs)
phased_path_vs = vd.get_phased_path(n, path_vs)
path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, s, e_vs, r)
path_ll_vs = vd.path_ll_dip(n, m, G_vs, phased_path_vs, query, e_vs, r)
self.assertAllClose(ll_vs, path_ll_vs)

(
Expand All @@ -34,7 +33,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
recombs_single,
recombs_double,
ll_tmp,
) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, s, e_vs, r)
) = vd.forwards_viterbi_dip_low_mem_no_pointer(n, m, G_vs, query, e_vs, r)
path_tmp = vd.backwards_viterbi_dip_no_pointer(
m,
V_argmaxes_tmp,
Expand All @@ -45,7 +44,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
V_tmp,
)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, query, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -54,29 +53,29 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
if num_ref_haps <= MAX_NUM_REF_HAPS:
# Run tests for the naive implementations.
V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive(
n, m, G_vs, s, e_vs, r
n, m, G_vs, query, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, query, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_low_mem(
n, m, G_vs, s, e_vs, r
n, m, G_vs, query, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp, P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, query, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vd.forwards_viterbi_dip_naive_vec(
n, m, G_vs, s, e_vs, r
n, m, G_vs, query, e_vs, r
)
path_tmp = vd.backwards_viterbi_dip(m, V_tmp[m - 1, :, :], P_tmp)
phased_path_tmp = vd.get_phased_path(n, path_tmp)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, s, e_vs, r)
path_ll_tmp = vd.path_ll_dip(n, m, G_vs, phased_path_tmp, query, e_vs, r)
self.assertAllClose(ll_tmp, path_ll_tmp)
self.assertAllClose(ll_vs, ll_tmp)

Expand All @@ -86,7 +85,7 @@ def test_ts_simple_n10_no_recomb(self, scale_mutation_rate, include_ancestors):
ts = self.get_ts_simple_n10_no_recomb()
self.verify(ts, scale_mutation_rate, include_ancestors)

@pytest.mark.parametrize("num_samples", [4, 8, 16])
@pytest.mark.parametrize("num_samples", [8, 16])
@pytest.mark.parametrize("scale_mutation_rate", [True, False])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_simple(self, num_samples, scale_mutation_rate, include_ancestors):
Expand Down

0 comments on commit 7c1bfde

Please sign in to comment.