From 63037e45e46c753dec3acc3f976995c4ee701d6d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 6 Sep 2024 15:50:35 +0100 Subject: [PATCH] Final test fixups --- tests/test_inference.py | 12 ++-- tests/test_tree_hueristics.py | 108 ---------------------------------- 2 files changed, 7 insertions(+), 113 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 846124a..5e88021 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -563,13 +563,14 @@ def test_first_day(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db assert list(ts.mutations_time) == [0, 0, 0] assert list(ts.mutations_site) == [8632, 17816, 27786] sc2ts_md = ts.node(2).metadata["sc2ts"] - assert len(sc2ts_md["mutations"]) == 3 - for mut_md, mut in zip(sc2ts_md["mutations"], ts.mutations()): + hmm_md = sc2ts_md["hmm"][0] + assert hmm_md["direction"] == "forward" + assert len(hmm_md["mutations"]) == 3 + for mut_md, mut in zip(hmm_md["mutations"], ts.mutations()): assert mut_md["derived_state"] == mut.derived_state - assert mut_md["site_id"] == mut.site assert mut_md["site_position"] == ts.sites_position[mut.site] assert mut_md["inherited_state"] == ts.site(mut.site).ancestral_state - assert sc2ts_md["path"] == [{"left": 0, "parent": 1, "right": 29904}] + assert hmm_md["path"] == [{"left": 0, "parent": 1, "right": 29904}] assert sc2ts_md["qc"] == { "num_masked_sites": 133, "original_base_composition": { @@ -665,7 +666,8 @@ def test_exact_matches(self, fx_ts_map, node, strain, parent): assert x.flags == (tskit.NODE_IS_SAMPLE | sc2ts.core.NODE_IS_EXACT_MATCH) md = x.metadata assert md["strain"] == strain - sc2ts_md = md["sc2ts"] + sc2ts_md = md["sc2ts"]["hmm"][0] + assert sc2ts_md["direction"] == "forward" assert len(sc2ts_md["path"]) == 1 assert len(sc2ts_md["mutations"]) == 0 assert sc2ts_md["path"][0] == { diff --git a/tests/test_tree_hueristics.py b/tests/test_tree_hueristics.py index 3c5b630..2393fc2 100644 --- a/tests/test_tree_hueristics.py +++ b/tests/test_tree_hueristics.py @@ -297,114 +297,6 @@ def test_two_sites_reversion_and_shared(self): assert ts2.num_nodes == ts.num_nodes + 1 -class TestInsertRecombinants: - def test_no_recombination(self): - ts1 = tskit.Tree.generate_balanced(4, arity=4).tree_sequence - ts2 = sc2ts.inference.insert_recombinants(ts1) - ts1.tables.assert_equals(ts2.tables) - - def test_single_breakpoint_single_recombinant_no_mutations(self): - tables = tskit.TableCollection(10) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, parent=0, child=2) - tables.edges.add_row(5, 10, parent=1, child=2) - ts = prepare(tables) - - ts2 = sc2ts.inference.insert_recombinants(ts) - assert_sequences_equal(ts, ts2) - assert ts2.num_mutations == 0 - assert ts2.num_nodes == ts.num_nodes + 1 - assert ts2.num_edges == ts.num_edges + 1 - assert_sequences_equal(ts, ts2) - - def test_single_breakpoint_two_recombinants_no_mutations(self): - tables = tskit.TableCollection(10) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=1, time=0) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, parent=0, child=2) - tables.edges.add_row(5, 10, parent=1, child=2) - tables.edges.add_row(0, 5, parent=0, child=3) - tables.edges.add_row(5, 10, parent=1, child=3) - ts = prepare(tables) - - ts2 = sc2ts.inference.insert_recombinants(ts) - assert_sequences_equal(ts, ts2) - assert ts2.num_mutations == 0 - assert ts2.num_nodes == ts.num_nodes + 1 - assert ts2.num_edges == ts.num_edges - assert_sequences_equal(ts, ts2) - - def test_single_breakpoint_single_recombinant_one_mutation(self): - tables = tskit.TableCollection(10) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, parent=0, child=2) - tables.edges.add_row(5, 10, parent=1, child=2) - tables.sites.add_row(4, "A") - tables.mutations.add_row(site=0, node=2, derived_state="T") - ts = prepare(tables) - - ts2 = sc2ts.inference.insert_recombinants(ts) - md = ts2.node(3).metadata - assert md["mutations"] == [[2, [[0, "A", "T"]]]] - assert_sequences_equal(ts, ts2) - assert ts2.num_mutations == 1 - assert ts2.num_nodes == ts.num_nodes + 1 - assert ts2.num_edges == ts.num_edges + 1 - assert np.all(ts2.mutations_node == 3) - - def test_single_breakpoint_single_recombinant_two_mutations(self): - tables = tskit.TableCollection(10) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, parent=0, child=2) - tables.edges.add_row(5, 10, parent=1, child=2) - tables.sites.add_row(4, "A") - tables.sites.add_row(5, "G") - tables.mutations.add_row(site=0, node=2, derived_state="T") - tables.mutations.add_row(site=1, node=2, derived_state="C") - ts = prepare(tables) - - ts2 = sc2ts.inference.insert_recombinants(ts) - md = ts2.node(3).metadata - assert md["mutations"] == [[2, [[0, "A", "T"], [1, "G", "C"]]]] - assert_sequences_equal(ts, ts2) - assert ts2.num_mutations == 2 - assert ts2.num_nodes == ts.num_nodes + 1 - assert ts2.num_edges == ts.num_edges + 1 - assert np.all(ts2.mutations_node == 3) - - def test_single_breakpoint_two_recombinants_different_mutations(self): - tables = tskit.TableCollection(10) - tables.sites.add_row(4, "A") - tables.sites.add_row(5, "G") - tables.nodes.add_row(flags=0, time=1) - tables.nodes.add_row(flags=0, time=1) - for j in [2, 3]: - tables.nodes.add_row(flags=1, time=0) - tables.edges.add_row(0, 5, parent=0, child=j) - tables.edges.add_row(5, 10, parent=1, child=j) - # Share the mutation at site 0 - tables.mutations.add_row(site=0, node=j, derived_state="T") - # Different mutations at site 1 - tables.mutations.add_row(site=1, node=2, derived_state="C") - tables.mutations.add_row(site=1, node=3, derived_state="T") - ts = prepare(tables) - - ts2 = sc2ts.inference.insert_recombinants(ts) - assert_sequences_equal(ts, ts2) - md = ts2.node(4).metadata - assert ts2.num_mutations == 3 - assert ts2.num_nodes == ts.num_nodes + 1 - assert ts2.num_edges == ts.num_edges - - class TestTrimBranches: def test_one_mutation_three_children(self): # 3.00┊ 6 ┊