diff --git a/sc2ts/inference.py b/sc2ts/inference.py index a6f032d..1670210 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -565,6 +565,8 @@ def extend( hmm_cost_threshold = 5 if min_group_size is None: min_group_size = 10 + if retrospective_window is None: + retrospective_window = 30 check_base_ts(base_ts) logger.info( diff --git a/tests/test_inference.py b/tests/test_inference.py index ba09948..846124a 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -10,15 +10,14 @@ class TestSolveNumMismatches: - @pytest.mark.parametrize( ["k", "expected_rho"], - [(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)], + [(2, 0.0001904), (3, 2.50582e-06), (4, 3.297146e-08), (1000, 0)], ) def test_examples(self, k, expected_rho): - rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2) - assert mu[0] == 0.125 - nt.assert_almost_equal(rho[0], expected_rho) + mu, rho = sc2ts.solve_num_mismatches(k) + assert mu == 0.0125 + nt.assert_almost_equal(rho, expected_rho) class TestInitialTs: @@ -204,10 +203,21 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path): class TestMatchTsinfer: - def match_tsinfer(self, samples, ts, **kwargs): + def match_tsinfer(self, samples, ts, mirror_coordinates=False, **kwargs): sc2ts.inference.match_tsinfer( - samples=samples, ts=ts, num_mismatches=1000, **kwargs + samples=samples, + ts=ts, + mu=0.125, + rho=0, + mirror_coordinates=mirror_coordinates, + **kwargs, ) + if mirror_coordinates: + # Quick hack to make the tests here work, as they use + # attributes defined by the forward path. + for sample in samples: + sample.forward_path = sample.reverse_path + sample.forward_mutations = sample.reverse_mutations @pytest.mark.parametrize("mirror", [False, True]) def test_match_reference(self, mirror): @@ -406,8 +416,8 @@ def test_n_samples_metadata(self): strain=strain, date=date, metadata={f"x{j}": j, f"y{j}": list(range(j))}, - path=[(0, ts.sequence_length, 1)], - mutations=[], + forward_path=[(0, ts.sequence_length, 1)], + forward_mutations=[], ) ) @@ -722,13 +732,13 @@ def test_exact_matches( samples = sc2ts.preprocess( [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) - # FIXME - mu = 0.125 + mu, rho = sc2ts.solve_num_mismatches(num_mismatches) sc2ts.match_tsinfer( samples=samples, ts=ts, - num_mismatches=num_mismatches, - likelihood_threshold = mu**num_mismatches - 1e-12, + mu=mu, + rho=rho, + likelihood_threshold=mu**num_mismatches - 1e-12, num_threads=0, ) s = samples[0] @@ -757,12 +767,13 @@ def test_one_mismatch( samples = sc2ts.preprocess( [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) + mu, rho = sc2ts.solve_num_mismatches(num_mismatches) sc2ts.match_tsinfer( samples=samples, ts=ts, - num_mismatches=num_mismatches, - # FIXME - likelihood_threshold=0.12499999, + mu=mu, + rho=rho, + likelihood_threshold=mu - 1e-5, num_threads=0, ) s = samples[0] @@ -785,11 +796,12 @@ def test_two_mismatches( samples = sc2ts.preprocess( [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) - mu = 0.125 + mu, rho = sc2ts.solve_num_mismatches(num_mismatches) sc2ts.match_tsinfer( samples=samples, ts=ts, - num_mismatches=num_mismatches, + mu=mu, + rho=rho, likelihood_threshold=mu**2 - 1e-12, num_threads=0, )