Skip to content

Commit

Permalink
Remove unary nodes from reversion push
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Sep 17, 2024
1 parent 082942e commit 3ebf4dd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
NODE_IS_REVERSION_PUSH = 1 << 22
NODE_IS_RECOMBINANT = 1 << 23
NODE_IS_EXACT_MATCH = 1 << 24
NODE_IS_IMMEDIATE_REVERSION_PARENT = 1 << 25
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25


__version__ = "undefined"
Expand Down
80 changes: 77 additions & 3 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,12 +871,20 @@ def add_matching_results(
md["masked_samples"] += int(site_masked_samples[int(site.position)])
tables.sites.append(site.replace(metadata=md))

# NOTE: Doing the parsimony hueristic updates really is complicated a lot
# by doing all of group batches together. It should be simpler if we reason
# about *one* tree being added at a time. We might get hit by a high-cost
# in terms of creating tree sequence objects over and again, but maybe we
# can do less sorting by thinking more clearly about how to add the edges.
# If we only add edges pointing to one parent at a time, we should be able
# to just insert them into the middle of the table?
tables.sort()
tables.build_index()
tables.compute_mutation_parents()
ts = tables.tree_sequence()
ts = push_up_reversions(ts, attach_nodes)
ts = coalesce_mutations(ts, attach_nodes)
ts = delete_immediate_reversion_nodes(ts, attach_nodes)
return ts


Expand Down Expand Up @@ -969,6 +977,7 @@ def update_tables(tables, edges_to_delete, mutations_to_delete):
# Updating the mutations is a real faff, and the only way I
# could get it to work is by setting the time values. This should
# be easier...
# NOTE: this should be easier to do now that we have the "keep_rows" methods
mutations_to_keep = np.ones(len(tables.mutations), dtype=bool)
mutations_to_keep[mutations_to_delete] = False
tables.mutations.replace_with(tables.mutations[mutations_to_keep])
Expand Down Expand Up @@ -1108,6 +1117,8 @@ def coalesce_mutations(ts, samples=None):
return update_tables(tables, edges_to_delete, mutations_to_delete)


# NOTE: "samples" is a bad name here, this is actually the set of attach_nodes
# that we get from making a local tree from a group.
def push_up_reversions(ts, samples):
# We depend on mutations having a time below.
assert np.all(np.logical_not(np.isnan(ts.mutations_time)))
Expand Down Expand Up @@ -1214,6 +1225,65 @@ def push_up_reversions(ts, samples):
return update_tables(tables, edges_to_delete, mutations_to_delete)


def is_full_span(tree, u):
"""
Returns true if the edge in which the specified node is a child
covers the full span of the tree sequence.
"""
ts = tree.tree_sequence
e = tree.edge(u)
assert e != -1
edge = ts.edge(e)
return edge.left == 0 and edge.right == ts.sequence_length


def delete_immediate_reversion_nodes(ts, attach_nodes):
tree = ts.first()
nodes_to_delete = []
for u in attach_nodes:
# If a node is a node inserted to track the immediate reversions
# shared by all the samples in a group, and it covers the full
# span (because it's easier), and it has no mutations, delete it.
condition = (
ts.nodes_flags[u] == core.NODE_IS_IMMEDIATE_REVERSION_MARKER
and is_full_span(tree, u)
and all(is_full_span(tree, v) for v in tree.children(u))
and np.sum(ts.mutations_node == u) == 0
)
if condition:
nodes_to_delete.append(u)

if len(nodes_to_delete) == 0:
return ts

# This is all quite a roundabout way of removing a node from the
# tree we shouldn't be adding in the first place. There must be a
# better way.
tables = ts.dump_tables()
edges_to_delete = []
for u in nodes_to_delete:
edges_to_delete.append(tree.edge(u))
parent = tree.parent(u)
assert tree.num_children(u) > 0
for v in tree.children(u):
e = tree.edge(v)
tables.edges[e] = ts.edge(e).replace(parent=parent)

keep_edges = np.ones(ts.num_edges, dtype=bool)
keep_edges[edges_to_delete] = 0
tables.edges.keep_rows(keep_edges)
keep_nodes = np.ones(ts.num_nodes, dtype=bool)
keep_nodes[nodes_to_delete] = 0
node_map = tables.nodes.keep_rows(keep_nodes)
tables.edges.child = node_map[tables.edges.child]
tables.edges.parent = node_map[tables.edges.parent]
tables.mutations.node = node_map[tables.mutations.node]
tables.sort()
tables.build_index()
logger.debug(f"Deleted {len(nodes_to_delete)} immediate reversion nodes")
return tables.tree_sequence()


# NOTE: could definitely do better here by using int encoding instead of
# strings, and then njit
@numba.jit(forceobj=True)
Expand Down Expand Up @@ -1698,16 +1768,20 @@ def attach_tree(
)

if len(group.immediate_reversions) > 0:
# Flag the node as an NODE_IS_IMMEDIATE_REVERSION_PARENT.
# It should be removed, along with the mutations we're adding here by
# Flag the node as an NODE_IS_IMMEDIATE_REVERSION_MARKER, which we've
# added as a unary above-the-root note above.
# This should be removed, along with the mutations we're adding here by
# push_up_reversions in all cases except recombinants (which we've wussed
# out on handling properly).
# This is all very roundabout, and we're also missing the opportunity
# to remove any non-immediate reversions if they exist withing
# the local tree group.
node = tree.children(tree.root)[0]
assert tree.num_children(tree.root) == 1
u = node_id_map[node]
row = parent_tables.nodes[u]
parent_tables.nodes[u] = row.replace(
flags=core.NODE_IS_IMMEDIATE_REVERSION_PARENT
flags=core.NODE_IS_IMMEDIATE_REVERSION_MARKER
)
# print("attaching reversions at ", node, node_id_map[node])
# print(child_ts.draw_text())
Expand Down
3 changes: 1 addition & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def test_2020_02_08(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d
rp_node = ts.node(tree.parent(node.id))
assert rp_node.flags == sc2ts.NODE_IS_REVERSION_PUSH
# TODO not too sure how helpful this metadata actually is, but lets
assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]}
# assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]}
ts.tables.assert_equals(fx_ts_map["2020-02-08"].tables, ignore_provenance=True)

sib_sample = ts.node(tree.siblings(node.id)[0])
Expand Down Expand Up @@ -804,7 +804,6 @@ def test_exact_matches(self, fx_ts_map, strain, parent):


class TestMatchingDetails:

@pytest.mark.parametrize(
("strain", "parent"), [("SRR11597207", 39), ("ERR4205570", 54)]
)
Expand Down

0 comments on commit 3ebf4dd

Please sign in to comment.