From 082942ee35c57239470a6e9536b2bd351b4a2a38 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 16 Sep 2024 23:43:29 +0100 Subject: [PATCH 1/4] Revert to older recurrent mutation handling approach --- sc2ts/core.py | 1 + sc2ts/inference.py | 70 +++++++++++++++++++++++++++++----------------- 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/sc2ts/core.py b/sc2ts/core.py index faf02e1..0a8c82c 100644 --- a/sc2ts/core.py +++ b/sc2ts/core.py @@ -18,6 +18,7 @@ NODE_IS_REVERSION_PUSH = 1 << 22 NODE_IS_RECOMBINANT = 1 << 23 NODE_IS_EXACT_MATCH = 1 << 24 +NODE_IS_IMMEDIATE_REVERSION_PARENT = 1 << 25 __version__ = "undefined" diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 518cdef..2b1b30d 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -704,7 +704,7 @@ def match_path_ts(group): node_id = add_sample_to_tables(sample, tables, group_id=group.sample_hash) tables.edges.add_row(0, tables.sequence_length, parent=0, child=node_id) for mut in sample.mutations: - if (mut.site_id, mut.derived_state) in group.reversions: + if (mut.site_id, mut.derived_state) in group.immediate_reversions: # We don't include any of the marked reversions so that they # aren't used in tree building. 
continue @@ -757,7 +757,7 @@ class SampleGroup: samples: List = None path: List = None - reversions: List = None + immediate_reversions: List = None sample_hash: str = None date_count: dict = dataclasses.field(default_factory=collections.Counter) @@ -786,7 +786,8 @@ def summary(self): f"Group {self.sample_hash} {len(self.samples)} samples " f"({dict(self.date_count)}) " f"attaching at {path_summary(self.path)} and " - f"reversions={self.reversions}; strains={self.strains}" + f"immediate_reversions={self.immediate_reversions}; " + f"strains={self.strains}" ) @@ -808,12 +809,12 @@ def add_matching_results( num_samples = 0 for sample in match_db.get(where_clause): path = tuple(sample.path) - reversions = tuple( + immediate_reversions = tuple( (mut.site_id, mut.derived_state) for mut in sample.mutations if mut.is_immediate_reversion ) - grouped_matches[(path, reversions)].append(sample) + grouped_matches[(path, immediate_reversions)].append(sample) num_samples += 1 if num_samples == 0: @@ -821,8 +822,8 @@ def add_matching_results( return ts groups = [ - SampleGroup(samples, path, reversions) - for (path, reversions), samples in grouped_matches.items() + SampleGroup(samples, path, immediate_reversions) + for (path, immediate_reversions), samples in grouped_matches.items() ] logger.info(f"Got {len(groups)} groups for {num_samples} samples") @@ -1191,6 +1192,11 @@ def push_up_reversions(ts, samples): tables.edges.add_row(0, ts.sequence_length, parent=w, child=sample) tables.edges.add_row(0, ts.sequence_length, parent=grandparent, child=w) + # Move any non-reversions mutations above the parent to the new node. 
+ for mut in np.where(ts.mutations_node == parent)[0]: + row = tables.mutations[mut] + if row.site not in sites: + tables.mutations[mut] = row.replace(node=w, time=w_time) for site in sites: # Delete the reversion mutations above the sample muts = np.where( @@ -1198,11 +1204,6 @@ def push_up_reversions(ts, samples): )[0] assert len(muts) == 1 mutations_to_delete.extend(muts) - # Move any non-reversions mutations above the parent to the new node. - for mut in np.where(ts.mutations_node == parent)[0]: - row = tables.mutations[mut] - if row.site not in sites: - tables.mutations[mut] = row.replace(node=w, time=w_time) num_del_mutations = len(mutations_to_delete) num_new_nodes = len(tables.nodes) - ts.num_nodes @@ -1611,7 +1612,6 @@ def attach_tree( epsilon=None, ): attach_path = group.path - reversions = group.reversions if epsilon is None: epsilon = 1e-6 # In time units of days ago @@ -1624,7 +1624,11 @@ def attach_tree( raise ValueError("Incompatible sequence length") tree = child_ts.first() - condition = np.any(child_ts.mutations_node == tree.root) or len(attach_path) > 1 + condition = ( + np.any(child_ts.mutations_node == tree.root) + or len(attach_path) > 1 + or len(group.immediate_reversions) > 0 + ) if condition: child_ts = add_root_edge(child_ts) tree = child_ts.first() @@ -1677,15 +1681,6 @@ def attach_tree( parent_tables.edges.add_row( seg.left, seg.right, parent=seg.parent, child=node_id_map[child] ) - # Add the reversion mutations over the attach nodes. These will be picked up - # by the reversion push code below, so no point in setting metadata. - for site_id, derived_state in reversions: - parent_tables.mutations.add_row( - site=site_id, - node=node_id_map[child], - derived_state=derived_state, - time=node_time[child], - ) # Add the mutations. for site in child_ts.sites(): @@ -1702,6 +1697,31 @@ def attach_tree( }, ) + if len(group.immediate_reversions) > 0: + # Flag the node as an NODE_IS_IMMEDIATE_REVERSION_PARENT. 
+ # It should be removed, along with the mutations we're adding here by + # push_up_reversions in all cases except recombinants (which we've wussed + # out on handling properly). + node = tree.children(tree.root)[0] + assert tree.num_children(tree.root) == 1 + u = node_id_map[node] + row = parent_tables.nodes[u] + parent_tables.nodes[u] = row.replace( + flags=core.NODE_IS_IMMEDIATE_REVERSION_PARENT + ) + # print("attaching reversions at ", node, node_id_map[node]) + # print(child_ts.draw_text()) + for site_id, derived_state in group.immediate_reversions: + parent_tables.mutations.add_row( + site=site_id, + node=u, + derived_state=derived_state, + time=node_time[node], + metadata={ + "sc2ts": {"type": "match_reversion", "group_id": group.sample_hash} + }, + ) + if len(attach_path) > 1: # Update the recombinant flags also. u = node_id_map[tree.children(tree.root)[0]] @@ -1711,7 +1731,7 @@ def attach_tree( return [node_id_map[u] for u in tree.children(tree.root)] -def add_root_edge(ts): +def add_root_edge(ts, flags=0): """ Add another node and edge above the root and rescale time back to 0-1. @@ -1721,7 +1741,7 @@ def add_root_edge(ts): root = ts.first().root # FIXME this is bogus. We should be doing all the time scaling by numbers # of mutations. 
- new_root = tables.nodes.add_row(time=1.25) + new_root = tables.nodes.add_row(time=1.25, flags=flags) tables.edges.add_row(0, ts.sequence_length, parent=new_root, child=root) tables.nodes.time /= np.max(tables.nodes.time) return tables.tree_sequence() From 3ebf4dd50d457d331e757f82f1e0c32919c45613 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 17 Sep 2024 10:19:28 +0100 Subject: [PATCH 2/4] Remove unary nodes from reversion push --- sc2ts/core.py | 2 +- sc2ts/inference.py | 80 +++++++++++++++++++++++++++++++++++++++-- tests/test_inference.py | 3 +- 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/sc2ts/core.py b/sc2ts/core.py index 0a8c82c..494d67c 100644 --- a/sc2ts/core.py +++ b/sc2ts/core.py @@ -18,7 +18,7 @@ NODE_IS_REVERSION_PUSH = 1 << 22 NODE_IS_RECOMBINANT = 1 << 23 NODE_IS_EXACT_MATCH = 1 << 24 -NODE_IS_IMMEDIATE_REVERSION_PARENT = 1 << 25 +NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25 __version__ = "undefined" diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 2b1b30d..8361de8 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -871,12 +871,20 @@ def add_matching_results( md["masked_samples"] += int(site_masked_samples[int(site.position)]) tables.sites.append(site.replace(metadata=md)) + # NOTE: Doing the parsimony heuristic updates really is complicated a lot + # by doing all of group batches together. It should be simpler if we reason + # about *one* tree being added at a time. We might get hit by a high-cost + # in terms of creating tree sequence objects over and again, but maybe we + # can do less sorting by thinking more clearly about how to add the edges. + # If we only add edges pointing to one parent at a time, we should be able + # to just insert them into the middle of the table? 
tables.sort() tables.build_index() tables.compute_mutation_parents() ts = tables.tree_sequence() ts = push_up_reversions(ts, attach_nodes) ts = coalesce_mutations(ts, attach_nodes) + ts = delete_immediate_reversion_nodes(ts, attach_nodes) return ts @@ -969,6 +977,7 @@ def update_tables(tables, edges_to_delete, mutations_to_delete): # Updating the mutations is a real faff, and the only way I # could get it to work is by setting the time values. This should # be easier... + # NOTE: this should be easier to do now that we have the "keep_rows" methods mutations_to_keep = np.ones(len(tables.mutations), dtype=bool) mutations_to_keep[mutations_to_delete] = False tables.mutations.replace_with(tables.mutations[mutations_to_keep]) @@ -1108,6 +1117,8 @@ def coalesce_mutations(ts, samples=None): return update_tables(tables, edges_to_delete, mutations_to_delete) +# NOTE: "samples" is a bad name here, this is actually the set of attach_nodes +# that we get from making a local tree from a group. def push_up_reversions(ts, samples): # We depend on mutations having a time below. assert np.all(np.logical_not(np.isnan(ts.mutations_time))) @@ -1214,6 +1225,65 @@ def push_up_reversions(ts, samples): return update_tables(tables, edges_to_delete, mutations_to_delete) +def is_full_span(tree, u): + """ + Returns true if the edge in which the specified node is a child + covers the full span of the tree sequence. + """ + ts = tree.tree_sequence + e = tree.edge(u) + assert e != -1 + edge = ts.edge(e) + return edge.left == 0 and edge.right == ts.sequence_length + + +def delete_immediate_reversion_nodes(ts, attach_nodes): + tree = ts.first() + nodes_to_delete = [] + for u in attach_nodes: + # If a node is a node inserted to track the immediate reversions + # shared by all the samples in a group, and it covers the full + # span (because it's easier), and it has no mutations, delete it. 
+ condition = ( + ts.nodes_flags[u] == core.NODE_IS_IMMEDIATE_REVERSION_MARKER + and is_full_span(tree, u) + and all(is_full_span(tree, v) for v in tree.children(u)) + and np.sum(ts.mutations_node == u) == 0 + ) + if condition: + nodes_to_delete.append(u) + + if len(nodes_to_delete) == 0: + return ts + + # This is all quite a roundabout way of removing a node from the + # tree we shouldn't be adding in the first place. There must be a + # better way. + tables = ts.dump_tables() + edges_to_delete = [] + for u in nodes_to_delete: + edges_to_delete.append(tree.edge(u)) + parent = tree.parent(u) + assert tree.num_children(u) > 0 + for v in tree.children(u): + e = tree.edge(v) + tables.edges[e] = ts.edge(e).replace(parent=parent) + + keep_edges = np.ones(ts.num_edges, dtype=bool) + keep_edges[edges_to_delete] = 0 + tables.edges.keep_rows(keep_edges) + keep_nodes = np.ones(ts.num_nodes, dtype=bool) + keep_nodes[nodes_to_delete] = 0 + node_map = tables.nodes.keep_rows(keep_nodes) + tables.edges.child = node_map[tables.edges.child] + tables.edges.parent = node_map[tables.edges.parent] + tables.mutations.node = node_map[tables.mutations.node] + tables.sort() + tables.build_index() + logger.debug(f"Deleted {len(nodes_to_delete)} immediate reversion nodes") + return tables.tree_sequence() + + # NOTE: could definitely do better here by using int encoding instead of # strings, and then njit @numba.jit(forceobj=True) @@ -1698,16 +1768,20 @@ def attach_tree( ) if len(group.immediate_reversions) > 0: - # Flag the node as an NODE_IS_IMMEDIATE_REVERSION_PARENT. - # It should be removed, along with the mutations we're adding here by + # Flag the node as an NODE_IS_IMMEDIATE_REVERSION_MARKER, which we've + # added as a unary above-the-root node above. + # This should be removed, along with the mutations we're adding here by # push_up_reversions in all cases except recombinants (which we've wussed # out on handling properly). 
+ # This is all very roundabout, and we're also missing the opportunity + # to remove any non-immediate reversions if they exist within + # the local tree group. node = tree.children(tree.root)[0] assert tree.num_children(tree.root) == 1 u = node_id_map[node] row = parent_tables.nodes[u] parent_tables.nodes[u] = row.replace( - flags=core.NODE_IS_IMMEDIATE_REVERSION_PARENT + flags=core.NODE_IS_IMMEDIATE_REVERSION_MARKER ) # print("attaching reversions at ", node, node_id_map[node]) # print(child_ts.draw_text()) diff --git a/tests/test_inference.py b/tests/test_inference.py index 849343f..f0a42ed 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -636,7 +636,7 @@ def test_2020_02_08(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d rp_node = ts.node(tree.parent(node.id)) assert rp_node.flags == sc2ts.NODE_IS_REVERSION_PUSH # TODO not too sure how helpful this metadata actually is, but lets - assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]} + # assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]} ts.tables.assert_equals(fx_ts_map["2020-02-08"].tables, ignore_provenance=True) sib_sample = ts.node(tree.siblings(node.id)[0]) @@ -804,7 +804,6 @@ def test_exact_matches(self, fx_ts_map, strain, parent): class TestMatchingDetails: - @pytest.mark.parametrize( ("strain", "parent"), [("SRR11597207", 39), ("ERR4205570", 54)] ) From c93d2c972c1b6a1e34a26fb16b0a1ccb3770d3da Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 17 Sep 2024 10:28:52 +0100 Subject: [PATCH 3/4] Fixup metadata and test --- sc2ts/inference.py | 9 ++++++--- tests/test_inference.py | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 8361de8..e974b2f 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -882,7 +882,7 @@ def add_matching_results( tables.build_index() tables.compute_mutation_parents() ts = tables.tree_sequence() - ts = push_up_reversions(ts, 
attach_nodes) + ts = push_up_reversions(ts, attach_nodes, date) ts = coalesce_mutations(ts, attach_nodes) ts = delete_immediate_reversion_nodes(ts, attach_nodes) return ts @@ -1119,7 +1119,7 @@ def coalesce_mutations(ts, samples=None): # NOTE: "samples" is a bad name here, this is actually the set of attach_nodes # that we get from making a local tree from a group. -def push_up_reversions(ts, samples): +def push_up_reversions(ts, samples, date): # We depend on mutations having a time below. assert np.all(np.logical_not(np.isnan(ts.mutations_time))) @@ -1192,8 +1192,11 @@ def push_up_reversions(ts, samples): time=w_time, metadata={ "sc2ts": { - "sample": int(sample), + # FIXME it's not clear how helpful the metadata is here + # If we had separate pass for each group, it would probably + # be easier to reason about. "sites": [int(x) for x in sites], + "date_added": date, } }, ) diff --git a/tests/test_inference.py b/tests/test_inference.py index f0a42ed..9fc1885 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -635,8 +635,7 @@ def test_2020_02_08(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d tree = ts.first() rp_node = ts.node(tree.parent(node.id)) assert rp_node.flags == sc2ts.NODE_IS_REVERSION_PUSH - # TODO not too sure how helpful this metadata actually is, but lets - # assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]} + assert rp_node.metadata["sc2ts"] == {"date_added": "2020-02-08", "sites": [4923]} ts.tables.assert_equals(fx_ts_map["2020-02-08"].tables, ignore_provenance=True) sib_sample = ts.node(tree.siblings(node.id)[0]) From 2e5a13ef0a5cdc8b8132abb345c927e559efb95c Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 17 Sep 2024 10:35:31 +0100 Subject: [PATCH 4/4] Fixup tests and add counters --- sc2ts/inference.py | 2 +- sc2ts/info.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index e974b2f..4bfe232 100644 --- 
a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1119,7 +1119,7 @@ def coalesce_mutations(ts, samples=None): # NOTE: "samples" is a bad name here, this is actually the set of attach_nodes # that we get from making a local tree from a group. -def push_up_reversions(ts, samples, date): +def push_up_reversions(ts, samples, date="1999-01-01"): # We depend on mutations having a time below. assert np.all(np.logical_not(np.isnan(ts.mutations_time))) diff --git a/sc2ts/info.py b/sc2ts/info.py index ab08444..4a51ac0 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -14,6 +14,7 @@ # https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065 + def find_runs(x): """Find runs of consecutive items in an array.""" @@ -110,15 +111,18 @@ def node_counts(self): pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH) re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT) exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0) + immediate_reversion_marker = np.sum( + (self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0 + ) nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0) - non_samples = (self.ts.nodes_flags & tskit.NODE_IS_SAMPLE) == 0 return { "sample": self.ts.num_samples, "ex": exact_matches, "mc": mc_nodes, "pr": pr_nodes, "re": re_nodes, + "imr": immediate_reversion_marker, "zero_muts": nodes_with_zero_muts, } @@ -275,6 +279,9 @@ def summary(self): pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH) re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT) exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0) + imr_nodes = np.sum( + (self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0 + ) samples = self.ts.samples() nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0) @@ -294,6 +301,7 @@ def summary(self): ("mc_nodes", mc_nodes), ("pr_nodes", pr_nodes), ("re_nodes", re_nodes), + ("imr_nodes", imr_nodes), ("mutations", 
self.ts.num_mutations), ("recurrent", np.sum(self.ts.mutations_parent != -1)), ("reversions", np.sum(self.mutations_is_reversion)),