Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lineage tools #305

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 130 additions & 45 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,9 +1399,16 @@ def get_sample_group_info(self, group_id):
if self.ts.nodes_flags[u] & tskit.NODE_IS_SAMPLE > 0:
samples.append(u)

tree = self.ts.first()
while self.nodes_metadata[u]["sc2ts"].get("group_id", None) == group_id:
tree = self.ts.first(tracked_samples=samples)
while tree.num_tracked_samples(u) < len(samples):
u = tree.parent(u)
assert tree.num_tracked_samples(u) == len(samples)
# Now go up one more
# NOTE: I'm not sure if this is what we want to do in all cases
# because we only add on these extra branches when there are root
# mutations and other corner cases.
u = tree.parent(u)
attach_node = u
attach_date = self.nodes_date[u]
ts = self.ts.simplify(samples + [u])
tables = ts.dump_tables()
Expand Down Expand Up @@ -1429,18 +1436,111 @@ def get_sample_group_info(self, group_id):
group_id,
self.nodes_sample_group[group_id],
ts=tables.tree_sequence(),
attach_node=attach_node,
attach_date=attach_date,
)

def pango_lineage_details(self, pango):
samples = self.pango_lineage_samples[pango]
first_node = samples[0]
first_node_group = self.nodes_metadata[first_node]["sc2ts"].get(
"group_id", None
)
if first_node_group is None:
# print(self.ts.nodes_flags[first_node])
assert self.ts.nodes_flags[first_node] & sc2ts.NODE_IS_EXACT_MATCH > 0
sgi = None
tree = ts.first()
root = tree.parent(first_node)
else:
sgi = self.get_sample_group_info(first_node_group)
root = sgi.attach_node

earliest_strain = self.nodes_metadata[first_node]["strain"]
earliest_strain_date = self.nodes_metadata[first_node]["date"]
descendant_counts = np.zeros(self.ts.num_trees, dtype=int)
for tree in self.ts.trees(tracked_samples=samples):
descendant_counts[tree.index] = tree.num_tracked_samples(root)
return PangoLineageDetails(
pango,
num_samples=len(samples),
descendant_counts=descendant_counts,
earliest_strain=earliest_strain,
earliest_strain_date=earliest_strain_date,
root_node=root,
group_id=first_node_group,
)


@dataclasses.dataclass
class PangoLineageDetails:
name: str
num_samples: int
root_node: int
descendant_counts: None
earliest_strain: str
earliest_strain_date: str
group_id: str

def summary_dict(self):
return {
"name": self.name,
"num_samples": self.num_samples,
"earliest_strain": self.earliest_strain,
"earliest_strain_date": self.earliest_strain_date,
"max_descendant_count": np.max(self.descendant_counts),
"min_descendant_count": np.min(self.descendant_counts),
"root_node": self.root_node,
"group_id": self.group_id,
}


@dataclasses.dataclass
class SampleGroupInfo:
group_id: str
nodes: List[int]
ts: tskit.TreeSequence
attach_node: int
attach_date: None

def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_labels=None, style=None, highlight_universal_mutations=None, x_regions=None, **kwargs):
def get_sample_metadata(self, key):
ret = []
for u in self.ts.samples():
node = self.ts.node(u)
ret.append(node.metadata[key])
return ret

@property
def lineages(self):
return self.get_sample_metadata("Viridian_pangolin")

@property
def strains(self):
return self.get_sample_metadata("strain")

@property
def sample_dates(self):
return np.array(self.get_sample_metadata("date"), dtype="datetime64[D]")

@property
def num_mutations(self):
return self.ts.num_mutations

@property
def num_recurrent_mutations(self):
return np.sum(self.ts.mutations_parent != -1)

def draw_svg(
self,
size=(800, 600),
time_scale=None,
y_axis=True,
mutation_labels=None,
style=None,
highlight_universal_mutations=True,
x_regions=None,
**kwargs,
):
"""
Draw an SVG representation of the tree of samples that trace to a single origin.

Expand All @@ -1466,7 +1566,6 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
"Spike": (21563, 25384),
}


if style is None:
style = ""
ts = self.ts
Expand All @@ -1479,7 +1578,7 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
times = list(np.unique(ts.nodes_time))
y_ticks = {times.index(k): v for k, v in y_ticks.items()}
shared_nodes = []
if highlight_universal_mutations is not None:
if highlight_universal_mutations:
# find edges above
tree = ts.first()
shared_nodes = [tree.root]
Expand All @@ -1504,7 +1603,9 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(parent.parent).derived_state
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
# Reverse map label name to mutation id, so we can count duplicates
Expand All @@ -1513,9 +1614,13 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
# If more than one mutation has the same label, add a prefix with the counts
mutation_labels = {
m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "")
for label, ids in mutation_labels.items() for i, m_id in enumerate(ids)}
for label, ids in mutation_labels.items()
for i, m_id in enumerate(ids)
}
# some default styles
styles = [".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"]
styles = [
".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"
]
if len(multiple_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in multiple_mutations)
Expand All @@ -1527,8 +1632,12 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
if len(universal_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in universal_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in universal_mutations)
sym_ax_css = ", ".join(f".x-axis .mut.m{m} .sym" for m in universal_mutations)
styles.append(lab_css + "{font-weight: bold}" + sym_css + "{stroke-width: 3}")
sym_ax_css = ", ".join(
f".x-axis .mut.m{m} .sym" for m in universal_mutations
)
styles.append(
lab_css + "{font-weight: bold}" + sym_css + "{stroke-width: 3}"
)
styles.append(sym_ax_css + "{stroke-width: 8}")
svg = self.ts.draw_svg(
size=size,
Expand All @@ -1543,11 +1652,14 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
# Hack to add genes to the X axis
if len(x_regions) > 0:
assert svg.startswith("<svg")
header = svg[:svg.find(">") + 1]
header = svg[: svg.find(">") + 1]
footer = "</svg>"

# Find SVG positions of the X axis
m = re.search(r'class="x-axis".*?class="ax-line" x1="([\d\.]+)" x2="([\d\.]+)" y1="([\d\.]+)"', svg)
m = re.search(
r'class="x-axis".*?class="ax-line" x1="([\d\.]+)" x2="([\d\.]+)" y1="([\d\.]+)"',
svg,
)
assert m is not None
x1, x2, y1 = float(m.group(1)), float(m.group(2)), float(m.group(3))
xdiff = x2 - x1
Expand All @@ -1556,46 +1668,19 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label
x_scale = xdiff / ts.sequence_length
x_boxes = [
x_box_svg.format(
x=x1 + p1 * x_scale,
w=(p2-p1) * x_scale,
y=y1,
h=20) # height of the box: hardcoded for now to match font height
x=x1 + p1 * x_scale, w=(p2 - p1) * x_scale, y=y1, h=20
) # height of the box: hardcoded for now to match font height
for p1, p2 in x_regions.values()
]
x_names = [
x_name_svg.format(x=x1 + (p[0] + p[1])/2 * x_scale, y=y1+2, name=name)
x_name_svg.format(
x=x1 + (p[0] + p[1]) / 2 * x_scale, y=y1 + 2, name=name
)
for name, p in x_regions.items()
]
# add the new SVG to the old
svg = (header + "".join(x_boxes) + "".join(x_names) + footer) + svg
svg = (header + "".join(x_boxes) + "".join(x_names) + footer) + svg
# Now wrap both in another SVG
svg = header + svg + footer

return tskit.drawing.SVGString(svg)

def get_sample_metadata(self, key):
ret = []
for u in self.ts.samples():
node = self.ts.node(u)
ret.append(node.metadata[key])
return ret

@property
def lineages(self):
return self.get_sample_metadata("Viridian_pangolin")

@property
def strains(self):
return self.get_sample_metadata("strain")

@property
def sample_dates(self):
return np.array(self.get_sample_metadata("date"), dtype="datetime64[D]")

@property
def num_mutations(self):
return self.ts.num_mutations

@property
def num_recurrent_mutations(self):
return np.sum(self.ts.mutations_parent != -1)