Skip to content

Commit

Permalink
Fixup recombinant information gathering
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Sep 17, 2024
1 parent 8e5099a commit e4f6d2e
Showing 1 changed file with 94 additions and 3 deletions.
97 changes: 94 additions & 3 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,82 @@
from IPython.display import Markdown, HTML

from . import core
from . import utils


# https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065


def get_recombinant_edges(ts):
    """
    Collect every partial edge in the tree sequence, keyed by its child node.

    An edge is partial when it does not span the whole sequence, i.e. its
    left coordinate is not 0 or its right coordinate is not
    ``ts.sequence_length``. The child of each partial edge must have node
    flags equal to ``sc2ts.NODE_IS_RECOMBINANT``.

    Returns a ``collections.defaultdict(list)`` mapping each child node to
    its partial edges sorted by left coordinate. For every child it is
    asserted that there are at least two edges, and that they abut one
    another and together span ``0`` to ``ts.sequence_length``.
    """
    is_partial = np.logical_or(
        ts.edges_left != 0, ts.edges_right != ts.sequence_length
    )
    by_child = collections.defaultdict(list)
    for edge_id in np.flatnonzero(is_partial):
        partial = ts.edge(edge_id)
        assert ts.nodes_flags[partial.child] == sc2ts.NODE_IS_RECOMBINANT
        by_child[partial.child].append(partial)

    # Each child's edges must tile the full sequence without gaps or overlaps.
    for segments in by_child.values():
        segments.sort(key=lambda seg: seg.left)
        assert len(segments) >= 2
        assert segments[0].left == 0
        assert segments[-1].right == ts.sequence_length
        for previous, current in zip(segments, segments[1:]):
            assert current.left == previous.right
    return by_child


def get_recombinant_mrca_table(ts):
    """
    Return a pandas data frame of the recombinant breakpoints from the
    specified tree sequence.

    Every adjacent pair of partial edges below the same child (as grouped
    by ``get_recombinant_edges``) defines one breakpoint and yields one row
    with the columns ``recombinant_node``, ``breakpoint``, ``left_parent``,
    ``right_parent`` and ``mrca``. Rows are ordered by breakpoint position
    and all values are stored as ``np.int32``.

    ``mrca`` is whatever ``get_path_mrca`` returns for the two root paths
    computed by ``get_root_path``: the right parent's path is taken in the
    tree starting at the breakpoint, the left parent's path in the tree
    immediately before it.

    When the tree sequence has no partial edges the result is an empty
    frame that still carries all five columns, so callers can index or
    join on ``recombinant_node`` without special-casing.
    """
    columns = [
        "recombinant_node",
        "breakpoint",
        "left_parent",
        "right_parent",
        "mrca",
    ]

    recombinant_edges = get_recombinant_edges(ts)
    # Split these up into adjacent pairs
    breakpoint_pairs = []
    for child, child_edges in recombinant_edges.items():
        for left_edge, right_edge in zip(child_edges, child_edges[1:]):
            assert left_edge.child == child
            breakpoint_pairs.append((left_edge, right_edge))
    assert len(breakpoint_pairs) >= len(recombinant_edges)

    # Visit breakpoints left to right (stable sort keeps ties in input order).
    breakpoint_pairs.sort(key=lambda pair: pair[1].left)

    data = []
    tree = ts.first()
    for left_edge, right_edge in breakpoint_pairs:
        assert left_edge.right == right_edge.left
        assert left_edge.child == right_edge.child
        bp = left_edge.right
        # The tree starting at the breakpoint holds the right parent ...
        tree.seek(bp)
        assert tree.interval.left == bp
        right_path = get_root_path(tree, right_edge.parent)
        # ... and the tree ending at the breakpoint holds the left parent.
        tree.prev()
        assert tree.interval.right == bp
        left_path = get_root_path(tree, left_edge.parent)
        mrca = get_path_mrca(left_path, right_path, ts.nodes_time)
        data.append(
            {
                "recombinant_node": left_edge.child,
                "breakpoint": bp,
                "left_parent": left_edge.parent,
                "right_parent": right_edge.parent,
                "mrca": mrca,
            }
        )
    # Explicit columns keep the schema intact even when there are no rows.
    return pd.DataFrame(data, columns=columns, dtype=np.int32)


def find_runs(x):
"""Find runs of consecutive items in an array."""

Expand Down Expand Up @@ -492,19 +563,39 @@ def site_summary(self, position):

def recombinants_summary(self):
    """
    Return a data frame with one row per recombinant node, indexed by node.

    Starts from ``self._collect_node_data(self.recombinants)`` and adds:

    - ``causal_strain``, ``causal_pango``, ``causal_date``: taken from the
      metadata of the sample that ``utils.get_recombinant_samples`` maps
      each recombinant node to.
    - ``breakpoint_interval_left`` / ``breakpoint_interval_right``: read
      from the node's ``["sc2ts"]["hmm"]`` metadata (the reverse and
      forward HMM runs respectively); both are 0 when a metadata key is
      missing.
    - the columns of ``utils.get_recombinant_mrca_table``, joined on the
      recombinant node ID.

    Asserts that the MRCA table has exactly one row per row of the node
    data frame.
    """
    df = self._collect_node_data(self.recombinants)
    sample_map = utils.get_recombinant_samples(self.ts)
    causal_strain = []
    causal_pango = []
    causal_date = []
    interval_left = []
    interval_right = []
    for u in df.node:
        md = self.nodes_metadata[sample_map[u]]
        causal_strain.append(md["strain"])
        causal_pango.append(md[self.pango_source])
        causal_date.append(md["date"])
        # Resolve both interval ends before appending either, so a missing
        # key can never leave the two lists with different lengths.
        try:
            hmm_md = self.nodes_metadata[u]["sc2ts"]["hmm"]
            assert hmm_md[0]["direction"] == "forward"
            assert hmm_md[1]["direction"] == "reverse"
            left = hmm_md[1]["path"][0]["right"]
            right = hmm_md[0]["path"][0]["right"]
        except KeyError:
            left = right = 0
        interval_left.append(left)
        interval_right.append(right)
    df["causal_strain"] = causal_strain
    df["causal_pango"] = causal_pango
    df["causal_date"] = causal_date
    df["breakpoint_interval_left"] = interval_left
    df["breakpoint_interval_right"] = interval_right

    df = df.set_index("node")
    mrca_table = utils.get_recombinant_mrca_table(self.ts).set_index(
        "recombinant_node"
    )
    assert len(mrca_table) == len(df)
    return df.join(mrca_table)

def combine_recombinant_info(self):
def get_imputed_pango(u, pango_source):
Expand All @@ -517,7 +608,7 @@ def get_imputed_pango(u, pango_source):
lineage = self.nodes_metadata[u]["Imputed_" + pango_source]
return lineage

df_arg = sc2ts.utils.get_recombinant_mrca_table(self.ts)
df_arg = utils.get_recombinant_mrca_table(self.ts)
arg_info = collections.defaultdict(list)
for _, row in df_arg.iterrows():
arg_info[row.recombinant_node].append(row)
Expand Down

0 comments on commit e4f6d2e

Please sign in to comment.