Merge pull request #277 from jeromekelleher/remove-recurrent-muts
Revert to older recurrent mutation handling approach
jeromekelleher authored Sep 17, 2024
2 parents 7b9b5c4 + 2e5a13e commit 4d3acd2
Showing 4 changed files with 136 additions and 32 deletions.
1 change: 1 addition & 0 deletions sc2ts/core.py
@@ -18,6 +18,7 @@
NODE_IS_REVERSION_PUSH = 1 << 22
NODE_IS_RECOMBINANT = 1 << 23
NODE_IS_EXACT_MATCH = 1 << 24
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25


__version__ = "undefined"
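For context (this sketch is not part of the commit): sc2ts node flags such as the new NODE_IS_IMMEDIATE_REVERSION_MARKER are single bits stored in the tskit nodes_flags column, so they can be tested either with equality (the node carries exactly that flag) or with a bitwise AND (the node carries that flag, possibly among others). A minimal sketch, assuming a tskit.TreeSequence `ts` is available:

```python
import numpy as np
import tskit

NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25

def count_marker_nodes(ts: tskit.TreeSequence) -> int:
    # Bitwise AND matches the flag even when other bits are also set;
    # `ts.nodes_flags == FLAG` would only match nodes with exactly this flag.
    return int(np.sum((ts.nodes_flags & NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0))
```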
153 changes: 125 additions & 28 deletions sc2ts/inference.py
@@ -704,7 +704,7 @@ def match_path_ts(group):
node_id = add_sample_to_tables(sample, tables, group_id=group.sample_hash)
tables.edges.add_row(0, tables.sequence_length, parent=0, child=node_id)
for mut in sample.mutations:
if (mut.site_id, mut.derived_state) in group.reversions:
if (mut.site_id, mut.derived_state) in group.immediate_reversions:
# We don't include any of the marked reversions so that they
# aren't used in tree building.
continue
@@ -757,7 +757,7 @@ class SampleGroup:

samples: List = None
path: List = None
reversions: List = None
immediate_reversions: List = None
sample_hash: str = None
date_count: dict = dataclasses.field(default_factory=collections.Counter)

@@ -786,7 +786,8 @@ def summary(self):
f"Group {self.sample_hash} {len(self.samples)} samples "
f"({dict(self.date_count)}) "
f"attaching at {path_summary(self.path)} and "
f"reversions={self.reversions}; strains={self.strains}"
f"immediate_reversions={self.immediate_reversions}; "
f"strains={self.strains}"
)


@@ -808,21 +809,21 @@ def add_matching_results(
num_samples = 0
for sample in match_db.get(where_clause):
path = tuple(sample.path)
reversions = tuple(
immediate_reversions = tuple(
(mut.site_id, mut.derived_state)
for mut in sample.mutations
if mut.is_immediate_reversion
)
grouped_matches[(path, reversions)].append(sample)
grouped_matches[(path, immediate_reversions)].append(sample)
num_samples += 1

if num_samples == 0:
logger.info("No candidate samples found in MatchDb")
return ts

groups = [
SampleGroup(samples, path, reversions)
for (path, reversions), samples in grouped_matches.items()
SampleGroup(samples, path, immediate_reversions)
for (path, immediate_reversions), samples in grouped_matches.items()
]
logger.info(f"Got {len(groups)} groups for {num_samples} samples")

@@ -870,12 +871,20 @@ def add_matching_results(
md["masked_samples"] += int(site_masked_samples[int(site.position)])
tables.sites.append(site.replace(metadata=md))

# NOTE: Doing the parsimony heuristic updates is complicated a lot
# by doing all of the group batches together. It should be simpler if we reason
# about *one* tree being added at a time. We might pay a high cost
# in terms of creating tree sequence objects over and over, but maybe we
# can do less sorting by thinking more clearly about how to add the edges.
# If we only add edges pointing to one parent at a time, we should be able
# to just insert them into the middle of the table?
tables.sort()
tables.build_index()
tables.compute_mutation_parents()
ts = tables.tree_sequence()
ts = push_up_reversions(ts, attach_nodes)
ts = push_up_reversions(ts, attach_nodes, date)
ts = coalesce_mutations(ts, attach_nodes)
ts = delete_immediate_reversion_nodes(ts, attach_nodes)
return ts
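To illustrate the grouping keyed on (path, immediate_reversions) above, here is a minimal, self-contained sketch using hypothetical stand-in sample records (the real MatchDb samples carry more fields):

```python
import collections
import dataclasses
from typing import List, Tuple

@dataclasses.dataclass(frozen=True)
class MiniSample:
    # Hypothetical stand-in for a matched sample record.
    path: Tuple[int, ...]
    immediate_reversions: Tuple[Tuple[int, str], ...]

def group_by_attachment(samples: List[MiniSample]):
    grouped = collections.defaultdict(list)
    for sample in samples:
        # Samples sharing both the copying path and the same set of
        # immediate reversions are attached together as one local tree.
        grouped[(sample.path, sample.immediate_reversions)].append(sample)
    return grouped

groups = group_by_attachment(
    [
        MiniSample((5,), ((100, "T"),)),
        MiniSample((5,), ((100, "T"),)),
        MiniSample((5,), ()),
    ]
)
assert len(groups) == 2  # two distinct (path, immediate_reversions) keys
```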


@@ -968,6 +977,7 @@ def update_tables(tables, edges_to_delete, mutations_to_delete):
# Updating the mutations is a real faff, and the only way I
# could get it to work is by setting the time values. This should
# be easier...
# NOTE: this should be easier to do now that we have the "keep_rows" methods
mutations_to_keep = np.ones(len(tables.mutations), dtype=bool)
mutations_to_keep[mutations_to_delete] = False
tables.mutations.replace_with(tables.mutations[mutations_to_keep])
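The NOTE above mentions the newer `keep_rows` table methods; a sketch of how the same mutation mask might be applied with them, assuming a tskit release that provides `MutationTable.keep_rows` (an illustration, not the code used in this commit):

```python
import numpy as np

def drop_mutations_with_keep_rows(tables, mutations_to_delete):
    # Boolean mask over mutation rows; rows set to False are removed.
    keep = np.ones(len(tables.mutations), dtype=bool)
    keep[mutations_to_delete] = False
    # keep_rows edits the table in place and returns the old-to-new id map
    # (assumed behaviour of the tskit versions that provide this method).
    return tables.mutations.keep_rows(keep)
```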
@@ -1107,7 +1117,9 @@ def coalesce_mutations(ts, samples=None):
return update_tables(tables, edges_to_delete, mutations_to_delete)


def push_up_reversions(ts, samples):
# NOTE: "samples" is a bad name here, this is actually the set of attach_nodes
# that we get from making a local tree from a group.
def push_up_reversions(ts, samples, date="1999-01-01"):
# We depend on mutations having a time below.
assert np.all(np.logical_not(np.isnan(ts.mutations_time)))

@@ -1180,8 +1192,11 @@ def push_up_reversions(ts, samples):
time=w_time,
metadata={
"sc2ts": {
"sample": int(sample),
# FIXME it's not clear how helpful the metadata is here.
# If we had a separate pass for each group, it would probably
# be easier to reason about.
"sites": [int(x) for x in sites],
"date_added": date,
}
},
)
@@ -1191,18 +1206,18 @@ def push_up_reversions(ts, samples):
tables.edges.add_row(0, ts.sequence_length, parent=w, child=sample)
tables.edges.add_row(0, ts.sequence_length, parent=grandparent, child=w)

# Move any non-reversion mutations above the parent to the new node.
for mut in np.where(ts.mutations_node == parent)[0]:
row = tables.mutations[mut]
if row.site not in sites:
tables.mutations[mut] = row.replace(node=w, time=w_time)
for site in sites:
# Delete the reversion mutations above the sample
muts = np.where(
np.logical_and(ts.mutations_node == sample, ts.mutations_site == site)
)[0]
assert len(muts) == 1
mutations_to_delete.extend(muts)
# Move any non-reversion mutations above the parent to the new node.
for mut in np.where(ts.mutations_node == parent)[0]:
row = tables.mutations[mut]
if row.site not in sites:
tables.mutations[mut] = row.replace(node=w, time=w_time)

num_del_mutations = len(mutations_to_delete)
num_new_nodes = len(tables.nodes) - ts.num_nodes
@@ -1213,6 +1228,65 @@ def push_up_reversions(ts, samples):
return update_tables(tables, edges_to_delete, mutations_to_delete)


def is_full_span(tree, u):
"""
Returns true if the edge in which the specified node is a child
covers the full span of the tree sequence.
"""
ts = tree.tree_sequence
e = tree.edge(u)
assert e != -1
edge = ts.edge(e)
return edge.left == 0 and edge.right == ts.sequence_length


def delete_immediate_reversion_nodes(ts, attach_nodes):
tree = ts.first()
nodes_to_delete = []
for u in attach_nodes:
# If a node was inserted to track the immediate reversions
# shared by all the samples in a group, covers the full
# span (because that's easier), and has no mutations, delete it.
condition = (
ts.nodes_flags[u] == core.NODE_IS_IMMEDIATE_REVERSION_MARKER
and is_full_span(tree, u)
and all(is_full_span(tree, v) for v in tree.children(u))
and np.sum(ts.mutations_node == u) == 0
)
if condition:
nodes_to_delete.append(u)

if len(nodes_to_delete) == 0:
return ts

# This is all quite a roundabout way of removing a node from the
# tree that we shouldn't be adding in the first place. There must be a
# better way.
tables = ts.dump_tables()
edges_to_delete = []
for u in nodes_to_delete:
edges_to_delete.append(tree.edge(u))
parent = tree.parent(u)
assert tree.num_children(u) > 0
for v in tree.children(u):
e = tree.edge(v)
tables.edges[e] = ts.edge(e).replace(parent=parent)

keep_edges = np.ones(ts.num_edges, dtype=bool)
keep_edges[edges_to_delete] = 0
tables.edges.keep_rows(keep_edges)
keep_nodes = np.ones(ts.num_nodes, dtype=bool)
keep_nodes[nodes_to_delete] = 0
node_map = tables.nodes.keep_rows(keep_nodes)
tables.edges.child = node_map[tables.edges.child]
tables.edges.parent = node_map[tables.edges.parent]
tables.mutations.node = node_map[tables.mutations.node]
tables.sort()
tables.build_index()
logger.debug(f"Deleted {len(nodes_to_delete)} immediate reversion nodes")
return tables.tree_sequence()


# NOTE: could definitely do better here by using int encoding instead of
# strings, and then njit
@numba.jit(forceobj=True)
@@ -1611,7 +1685,6 @@ def attach_tree(
epsilon=None,
):
attach_path = group.path
reversions = group.reversions
if epsilon is None:
epsilon = 1e-6 # In time units of days ago

@@ -1624,7 +1697,11 @@ def attach_tree(
raise ValueError("Incompatible sequence length")

tree = child_ts.first()
condition = np.any(child_ts.mutations_node == tree.root) or len(attach_path) > 1
condition = (
np.any(child_ts.mutations_node == tree.root)
or len(attach_path) > 1
or len(group.immediate_reversions) > 0
)
if condition:
child_ts = add_root_edge(child_ts)
tree = child_ts.first()
@@ -1677,15 +1754,6 @@ def attach_tree(
parent_tables.edges.add_row(
seg.left, seg.right, parent=seg.parent, child=node_id_map[child]
)
# Add the reversion mutations over the attach nodes. These will be picked up
# by the reversion push code below, so no point in setting metadata.
for site_id, derived_state in reversions:
parent_tables.mutations.add_row(
site=site_id,
node=node_id_map[child],
derived_state=derived_state,
time=node_time[child],
)

# Add the mutations.
for site in child_ts.sites():
@@ -1702,6 +1770,35 @@ def attach_tree(
},
)

if len(group.immediate_reversions) > 0:
# Flag the node as NODE_IS_IMMEDIATE_REVERSION_MARKER; this is the
# unary above-the-root node that we added above.
# This node, along with the mutations we're adding here, should be removed
# by push_up_reversions in all cases except recombinants (which we've wussed
# out on handling properly).
# This is all very roundabout, and we're also missing the opportunity
# to remove any non-immediate reversions if they exist within
# the local tree group.
node = tree.children(tree.root)[0]
assert tree.num_children(tree.root) == 1
u = node_id_map[node]
row = parent_tables.nodes[u]
parent_tables.nodes[u] = row.replace(
flags=core.NODE_IS_IMMEDIATE_REVERSION_MARKER
)
# print("attaching reversions at ", node, node_id_map[node])
# print(child_ts.draw_text())
for site_id, derived_state in group.immediate_reversions:
parent_tables.mutations.add_row(
site=site_id,
node=u,
derived_state=derived_state,
time=node_time[node],
metadata={
"sc2ts": {"type": "match_reversion", "group_id": group.sample_hash}
},
)

if len(attach_path) > 1:
# Update the recombinant flags also.
u = node_id_map[tree.children(tree.root)[0]]
@@ -1711,7 +1808,7 @@ def attach_tree(
return [node_id_map[u] for u in tree.children(tree.root)]


def add_root_edge(ts):
def add_root_edge(ts, flags=0):
"""
Add another node and edge above the root and rescale time back to
0-1.
@@ -1721,7 +1818,7 @@ def add_root_edge(ts):
root = ts.first().root
# FIXME this is bogus. We should be doing all the time scaling by numbers
# of mutations.
new_root = tables.nodes.add_row(time=1.25)
new_root = tables.nodes.add_row(time=1.25, flags=flags)
tables.edges.add_row(0, ts.sequence_length, parent=new_root, child=root)
tables.nodes.time /= np.max(tables.nodes.time)
return tables.tree_sequence()
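`add_root_edge` now accepts an optional `flags` value for the new root node. A hypothetical call illustrating the parameter (the hunks shown above instead set the flag afterwards via `nodes[u].replace(...)` in `attach_tree`, and any other use of `flags` is not visible in this diff):

```python
# Hypothetical usage of the new parameter; not taken from this diff.
child_ts = add_root_edge(child_ts, flags=core.NODE_IS_IMMEDIATE_REVERSION_MARKER)
```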
10 changes: 9 additions & 1 deletion sc2ts/info.py
@@ -14,6 +14,7 @@

# https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065


def find_runs(x):
"""Find runs of consecutive items in an array."""

@@ -110,15 +111,18 @@ def node_counts(self):
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0)
immediate_reversion_marker = np.sum(
(self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0
)

nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0)
non_samples = (self.ts.nodes_flags & tskit.NODE_IS_SAMPLE) == 0
return {
"sample": self.ts.num_samples,
"ex": exact_matches,
"mc": mc_nodes,
"pr": pr_nodes,
"re": re_nodes,
"imr": immediate_reversion_marker,
"zero_muts": nodes_with_zero_muts,
}

@@ -275,6 +279,9 @@ def summary(self):
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0)
imr_nodes = np.sum(
(self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0
)

samples = self.ts.samples()
nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0)
@@ -294,6 +301,7 @@ def summary(self):
("mc_nodes", mc_nodes),
("pr_nodes", pr_nodes),
("re_nodes", re_nodes),
("imr_nodes", imr_nodes),
("mutations", self.ts.num_mutations),
("recurrent", np.sum(self.ts.mutations_parent != -1)),
("reversions", np.sum(self.mutations_is_reversion)),
4 changes: 1 addition & 3 deletions tests/test_inference.py
@@ -635,8 +635,7 @@ def test_2020_02_08(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d
tree = ts.first()
rp_node = ts.node(tree.parent(node.id))
assert rp_node.flags == sc2ts.NODE_IS_REVERSION_PUSH
# TODO not too sure how helpful this metadata actually is, but lets
assert rp_node.metadata["sc2ts"] == {"sample": node.id, "sites": [4923]}
assert rp_node.metadata["sc2ts"] == {"date_added": "2020-02-08", "sites": [4923]}
ts.tables.assert_equals(fx_ts_map["2020-02-08"].tables, ignore_provenance=True)

sib_sample = ts.node(tree.siblings(node.id)[0])
@@ -804,7 +803,6 @@ def test_exact_matches(self, fx_ts_map, strain, parent):


class TestMatchingDetails:

@pytest.mark.parametrize(
("strain", "parent"), [("SRR11597207", 39), ("ERR4205570", 54)]
)
