From 2d33ad3bf2eec57d9101018d612fbfbca335469a Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 25 Sep 2024 17:10:45 +0100 Subject: [PATCH 1/2] Minor updates --- sc2ts/info.py | 121 +++++++++++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 45 deletions(-) diff --git a/sc2ts/info.py b/sc2ts/info.py index 580cf5f..8029e67 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1399,9 +1399,16 @@ def get_sample_group_info(self, group_id): if self.ts.nodes_flags[u] & tskit.NODE_IS_SAMPLE > 0: samples.append(u) - tree = self.ts.first() - while self.nodes_metadata[u]["sc2ts"].get("group_id", None) == group_id: + tree = self.ts.first(tracked_samples=samples) + while tree.num_tracked_samples(u) < len(samples): u = tree.parent(u) + assert tree.num_tracked_samples(u) == len(samples) + # Now go up one more + # NOTE: I'm not sure if this is what we want to do in all cases + # because we only add on these extra branches when there are root + # mutations and other corner cases. 
+ u = tree.parent(u) + attach_node = u attach_date = self.nodes_date[u] ts = self.ts.simplify(samples + [u]) tables = ts.dump_tables() @@ -1429,6 +1436,7 @@ def get_sample_group_info(self, group_id): group_id, self.nodes_sample_group[group_id], ts=tables.tree_sequence(), + attach_node=attach_node, attach_date=attach_date, ) @@ -1438,9 +1446,47 @@ class SampleGroupInfo: group_id: str nodes: List[int] ts: tskit.TreeSequence + attach_node: int attach_date: None - def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_labels=None, style=None, highlight_universal_mutations=None, x_regions=None, **kwargs): + def get_sample_metadata(self, key): + ret = [] + for u in self.ts.samples(): + node = self.ts.node(u) + ret.append(node.metadata[key]) + return ret + + @property + def lineages(self): + return self.get_sample_metadata("Viridian_pangolin") + + @property + def strains(self): + return self.get_sample_metadata("strain") + + @property + def sample_dates(self): + return np.array(self.get_sample_metadata("date"), dtype="datetime64[D]") + + @property + def num_mutations(self): + return self.ts.num_mutations + + @property + def num_recurrent_mutations(self): + return np.sum(self.ts.mutations_parent != -1) + + def draw_svg( + self, + size=(800, 600), + time_scale=None, + y_axis=True, + mutation_labels=None, + style=None, + highlight_universal_mutations=True, + x_regions=None, + **kwargs, + ): """ Draw an SVG representation of the tree of samples that trace to a single origin. 
@@ -1466,7 +1512,6 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label "Spike": (21563, 25384), } - if style is None: style = "" ts = self.ts @@ -1479,7 +1524,7 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label times = list(np.unique(ts.nodes_time)) y_ticks = {times.index(k): v for k, v in y_ticks.items()} shared_nodes = [] - if highlight_universal_mutations is not None: + if highlight_universal_mutations: # find edges above tree = ts.first() shared_nodes = [tree.root] @@ -1504,7 +1549,9 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label inherited_state = parent.derived_state parent_inherited_state = site.ancestral_state if parent.parent >= 0: - parent_inherited_state = ts.mutation(parent.parent).derived_state + parent_inherited_state = ts.mutation( + parent.parent + ).derived_state if parent_inherited_state == mut.derived_state: reverted_mutations.append(mut.id) # Reverse map label name to mutation id, so we can count duplicates @@ -1513,9 +1560,13 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label # If more than one mutation has the same label, add a prefix with the counts mutation_labels = { m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "") - for label, ids in mutation_labels.items() for i, m_id in enumerate(ids)} + for label, ids in mutation_labels.items() + for i, m_id in enumerate(ids) + } # some default styles - styles = [".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"] + styles = [ + ".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}" + ] if len(multiple_mutations) > 0: lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations) sym_css = ", ".join(f".mut.m{m} .sym" for m in multiple_mutations) @@ -1527,8 +1578,12 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label if len(universal_mutations) > 0: 
lab_css = ", ".join(f".mut.m{m} .lab" for m in universal_mutations) sym_css = ", ".join(f".mut.m{m} .sym" for m in universal_mutations) - sym_ax_css = ", ".join(f".x-axis .mut.m{m} .sym" for m in universal_mutations) - styles.append(lab_css + "{font-weight: bold}" + sym_css + "{stroke-width: 3}") + sym_ax_css = ", ".join( + f".x-axis .mut.m{m} .sym" for m in universal_mutations + ) + styles.append( + lab_css + "{font-weight: bold}" + sym_css + "{stroke-width: 3}" + ) styles.append(sym_ax_css + "{stroke-width: 8}") svg = self.ts.draw_svg( size=size, @@ -1543,11 +1598,14 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label # Hack to add genes to the X axis if len(x_regions) > 0: assert svg.startswith("") + 1] + header = svg[: svg.find(">") + 1] footer = "" # Find SVG positions of the X axis - m = re.search(r'class="x-axis".*?class="ax-line" x1="([\d\.]+)" x2="([\d\.]+)" y1="([\d\.]+)"', svg) + m = re.search( + r'class="x-axis".*?class="ax-line" x1="([\d\.]+)" x2="([\d\.]+)" y1="([\d\.]+)"', + svg, + ) assert m is not None x1, x2, y1 = float(m.group(1)), float(m.group(2)), float(m.group(3)) xdiff = x2 - x1 @@ -1556,46 +1614,19 @@ def draw_svg(self, size=(800, 600), time_scale=None, y_axis=True, mutation_label x_scale = xdiff / ts.sequence_length x_boxes = [ x_box_svg.format( - x=x1 + p1 * x_scale, - w=(p2-p1) * x_scale, - y=y1, - h=20) # height of the box: hardcoded for now to match font height + x=x1 + p1 * x_scale, w=(p2 - p1) * x_scale, y=y1, h=20 + ) # height of the box: hardcoded for now to match font height for p1, p2 in x_regions.values() ] x_names = [ - x_name_svg.format(x=x1 + (p[0] + p[1])/2 * x_scale, y=y1+2, name=name) + x_name_svg.format( + x=x1 + (p[0] + p[1]) / 2 * x_scale, y=y1 + 2, name=name + ) for name, p in x_regions.items() ] # add the new SVG to the old - svg = (header + "".join(x_boxes) + "".join(x_names) + footer) + svg + svg = (header + "".join(x_boxes) + "".join(x_names) + footer) + svg # Now wrap both in 
def pango_lineage_details(self, pango):
    """
    Summarise how the samples assigned to a Pango lineage sit in the trees.

    The first entry of ``self.pango_lineage_samples[pango]`` is used as the
    representative sample: its ``strain`` and ``date`` metadata are reported
    as the earliest strain, and its sample group (if any) determines the
    root node for the lineage.

    :param pango: Key into ``self.pango_lineage_samples``.
    :return: A ``PangoLineageDetails`` instance whose ``descendant_counts``
        holds, for each tree, the number of the lineage's samples that
        descend from the chosen root node.
    :raises KeyError: If ``pango`` is not in ``self.pango_lineage_samples``.
    """
    samples = self.pango_lineage_samples[pango]
    # NOTE(review): presumably the sample list is ordered so that the first
    # entry is the earliest sample -- confirm against where
    # pango_lineage_samples is built.
    first_node = samples[0]
    first_node_md = self.nodes_metadata[first_node]
    first_node_group = first_node_md["sc2ts"].get("group_id")

    if first_node_group is None:
        # A sample without a group is expected to be an exact match; use its
        # parent in the first tree as the root.
        assert self.ts.nodes_flags[first_node] & sc2ts.NODE_IS_EXACT_MATCH > 0
        # The original referenced a bare ``ts`` here, which is not bound in
        # this method; the tree sequence lives on ``self``.
        tree = self.ts.first()
        root = tree.parent(first_node)
    else:
        root = self.get_sample_group_info(first_node_group).attach_node

    # Track only this lineage's samples so num_tracked_samples(root) counts
    # how many of them sit below the root in each tree.
    descendant_counts = np.zeros(self.ts.num_trees, dtype=int)
    for tree in self.ts.trees(tracked_samples=samples):
        descendant_counts[tree.index] = tree.num_tracked_samples(root)

    return PangoLineageDetails(
        pango,
        num_samples=len(samples),
        descendant_counts=descendant_counts,
        earliest_strain=first_node_md["strain"],
        earliest_strain_date=first_node_md["date"],
        root_node=root,
        group_id=first_node_group,
    )
@dataclasses.dataclass
class PangoLineageDetails:
    """
    Summary of a Pango lineage, as assembled by ``pango_lineage_details``.
    """

    # Pango lineage name.
    name: str
    # Number of samples assigned to the lineage.
    num_samples: int
    # Node taken as the root of the lineage.
    root_node: int
    # One entry per tree: number of the lineage's samples below root_node.
    descendant_counts: np.ndarray
    # "strain" / "date" metadata of the first sample in the lineage.
    earliest_strain: str
    earliest_strain_date: str
    # Sample group of the first sample; pango_lineage_details passes None
    # when that sample has no "group_id" in its metadata.
    group_id: str

    def summary_dict(self):
        """
        Return a flat dict of scalar summary values for this lineage.

        The per-tree ``descendant_counts`` array is reduced to its maximum
        and minimum, converted to plain ``int`` so the dict holds only
        built-in Python types.

        :raises ValueError: If ``descendant_counts`` is empty (raised by
            ``np.max``).
        """
        counts = self.descendant_counts
        return {
            "name": self.name,
            "num_samples": self.num_samples,
            "earliest_strain": self.earliest_strain,
            "earliest_strain_date": self.earliest_strain_date,
            "max_descendant_count": int(np.max(counts)),
            "min_descendant_count": int(np.min(counts)),
            "root_node": self.root_node,
            "group_id": self.group_id,
        }