Skip to content

Commit

Permalink
Fixup tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Sep 6, 2024
1 parent c247543 commit 6b73b6e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
2 changes: 2 additions & 0 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ def extend(
hmm_cost_threshold = 5
if min_group_size is None:
min_group_size = 10
if retrospective_window is None:
retrospective_window = 30

check_base_ts(base_ts)
logger.info(
Expand Down
48 changes: 30 additions & 18 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@


class TestSolveNumMismatches:

@pytest.mark.parametrize(
["k", "expected_rho"],
[(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)],
[(2, 0.0001904), (3, 2.50582e-06), (4, 3.297146e-08), (1000, 0)],
)
def test_examples(self, k, expected_rho):
rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2)
assert mu[0] == 0.125
nt.assert_almost_equal(rho[0], expected_rho)
mu, rho = sc2ts.solve_num_mismatches(k)
assert mu == 0.0125
nt.assert_almost_equal(rho, expected_rho)


class TestInitialTs:
Expand Down Expand Up @@ -204,10 +203,21 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):


class TestMatchTsinfer:
def match_tsinfer(self, samples, ts, **kwargs):
def match_tsinfer(self, samples, ts, mirror_coordinates=False, **kwargs):
sc2ts.inference.match_tsinfer(
samples=samples, ts=ts, num_mismatches=1000, **kwargs
samples=samples,
ts=ts,
mu=0.125,
rho=0,
mirror_coordinates=mirror_coordinates,
**kwargs,
)
if mirror_coordinates:
# Quick hack to make the tests here work, as they use
# attributes defined by the forward path.
for sample in samples:
sample.forward_path = sample.reverse_path
sample.forward_mutations = sample.reverse_mutations

@pytest.mark.parametrize("mirror", [False, True])
def test_match_reference(self, mirror):
Expand Down Expand Up @@ -406,8 +416,8 @@ def test_n_samples_metadata(self):
strain=strain,
date=date,
metadata={f"x{j}": j, f"y{j}": list(range(j))},
path=[(0, ts.sequence_length, 1)],
mutations=[],
forward_path=[(0, ts.sequence_length, 1)],
forward_mutations=[],
)
)

Expand Down Expand Up @@ -722,13 +732,13 @@ def test_exact_matches(
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
# FIXME
mu = 0.125
mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
likelihood_threshold = mu**num_mismatches - 1e-12,
mu=mu,
rho=rho,
likelihood_threshold=mu**num_mismatches - 1e-12,
num_threads=0,
)
s = samples[0]
Expand Down Expand Up @@ -757,12 +767,13 @@ def test_one_mismatch(
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
# FIXME
likelihood_threshold=0.12499999,
mu=mu,
rho=rho,
likelihood_threshold=mu - 1e-5,
num_threads=0,
)
s = samples[0]
Expand All @@ -785,11 +796,12 @@ def test_two_mismatches(
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
mu = 0.125
mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
mu=mu,
rho=rho,
likelihood_threshold=mu**2 - 1e-12,
num_threads=0,
)
Expand Down

0 comments on commit 6b73b6e

Please sign in to comment.