Skip to content

Commit

Permalink
Merge pull request #276 from jeromekelleher/fixup-tree-info
Browse files Browse the repository at this point in the history
Fixup tree info
  • Loading branch information
jeromekelleher authored Sep 16, 2024
2 parents 4c4519e + 266e92e commit 7b9b5c4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 47 deletions.
6 changes: 4 additions & 2 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ def info_ts(ts_path, verbose, log_file):
setup_logging(verbose, log_file)
ts = tszip.load(ts_path)

ti = sc2ts.TreeInfo(ts, quick=True)
print("info", ti.node_counts())
ti = sc2ts.TreeInfo(ts, quick=False)
# print("info", ti.node_counts())
print(ti.summary())
# TODO more
# print(ti.recombinants_summary())

def add_provenance(ts, output_file):
# Record provenance here because this is where the arguments are provided.
Expand Down
55 changes: 40 additions & 15 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,41 @@
import tskit
import numpy as np
import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors
from IPython.display import Markdown, HTML

from . import core


def get_recombinants(ts):
    """Return the unique IDs of nodes that are the child of a partial edge.

    An edge is partial when it does not span the whole sequence, i.e. its
    left coordinate is not 0 or its right coordinate is not
    ``ts.sequence_length``. The result comes from ``np.unique``, so it is
    sorted and free of duplicates.
    """
    starts_at_zero = ts.edges_left == 0
    ends_at_length = ts.edges_right == ts.sequence_length
    # De Morgan: "left != 0 or right != L" is "not (left == 0 and right == L)".
    is_partial = ~(starts_at_zero & ends_at_length)
    children = ts.edges_child[is_partial]
    return np.unique(children)
# https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065

def find_runs(x):
    """Find runs of consecutive equal items in a 1D array.

    Parameters
    ----------
    x : array_like
        One-dimensional input; converted with ``np.asanyarray``.

    Returns
    -------
    run_values : ndarray
        The value of each run, in order of appearance.
    run_starts : ndarray
        Index in ``x`` at which each run begins.
    run_lengths : ndarray
        Number of items in each run.

    Raises
    ------
    ValueError
        If ``x`` is not one-dimensional.
    """
    x = np.asanyarray(x)
    if x.ndim != 1:
        raise ValueError("only 1D array supported")
    n = x.shape[0]

    if n == 0:
        # Keep dtypes consistent with the non-empty path: values carry the
        # input dtype, starts and lengths are integer index arrays.
        return (
            x[:0].copy(),
            np.array([], dtype=np.intp),
            np.array([], dtype=np.intp),
        )

    # A run starts at position 0 and wherever an item differs from its
    # predecessor.
    loc_run_start = np.empty(n, dtype=bool)
    loc_run_start[0] = True
    np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
    run_starts = np.nonzero(loc_run_start)[0]

    run_values = x[loc_run_start]
    # Appending n closes the final run so that diff yields every length.
    run_lengths = np.diff(np.append(run_starts, n))
    return run_values, run_starts, run_lengths


def max_descendant_samples(ts, show_progress=True):
Expand Down Expand Up @@ -53,7 +78,7 @@ def __init__(
self.ts = ts
self.pango_source = pango_source
self.strain_map = {}
self.recombinants = get_recombinants(ts)
self.recombinants = np.where(ts.nodes_flags == core.NODE_IS_RECOMBINANT)[0]

self.nodes_max_descendant_samples = None
self.nodes_date = None
Expand Down Expand Up @@ -98,6 +123,7 @@ def node_counts(self):
}

def _preprocess_nodes(self, show_progress):
ts = self.ts
self.nodes_max_descendant_samples = max_descendant_samples(ts)
self.nodes_date = np.zeros(ts.num_nodes, dtype="datetime64[D]")
self.nodes_num_masked_sites = np.zeros(ts.num_nodes, dtype=np.int32)
Expand All @@ -120,15 +146,14 @@ def _preprocess_nodes(self, show_progress):
md = node.metadata
self.nodes_metadata[node.id] = md
if node.is_sample():
self.strain_map[md["strain"]] = node.id
self.nodes_date[node.id] = md["date"]
pango = md.get(pango_source, "unknown")
pango = md.get(self.pango_source, "unknown")
self.pango_lineage_samples[pango].append(node.id)
if "sc2ts" in md:
try:
qc = md["sc2ts"]["qc"]
self.nodes_num_masked_sites[node.id] = qc["num_masked_sites"]
else:
if node.id != 1:
except KeyError:
if node.id > 1:
warnings.warn("Node QC metadata not available")
else:
# Rounding down here, might be misleading
Expand All @@ -138,8 +163,8 @@ def _preprocess_nodes(self, show_progress):

def _preprocess_sites(self, show_progress):
self.sites_num_masked_samples = np.zeros(self.ts.num_sites, dtype=int)
if ts.table_metadata_schemas.site.schema is not None:
for site in ts.sites():
if self.ts.table_metadata_schemas.site.schema is not None:
for site in self.ts.sites():
self.sites_num_masked_samples[site.id] = site.metadata["masked_samples"]
else:
warnings.warn("Site QC metadata unavailable")
Expand Down Expand Up @@ -245,6 +270,7 @@ def _preprocess_mutations(self, show_progress):
self.sites_num_transversions = sites_num_transversions

def summary(self):
# TODO use the node_counts function above
mc_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_MUTATION_OVERLAP)
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
Expand All @@ -268,7 +294,6 @@ def summary(self):
("mc_nodes", mc_nodes),
("pr_nodes", pr_nodes),
("re_nodes", re_nodes),
("recombinants", len(self.recombinants)),
("mutations", self.ts.num_mutations),
("recurrent", np.sum(self.ts.mutations_parent != -1)),
("reversions", np.sum(self.mutations_is_reversion)),
Expand Down
30 changes: 0 additions & 30 deletions sc2ts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,36 +253,6 @@ def asdict(self):
return dataclasses.asdict(self)


# https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065
def find_runs(x):
"""Find runs of consecutive items in an array."""

# ensure array
x = np.asanyarray(x)
if x.ndim != 1:
raise ValueError("only 1D array supported")
n = x.shape[0]

# handle empty array
if n == 0:
return np.array([]), np.array([]), np.array([])

else:
# find run starts
loc_run_start = np.empty(n, dtype=bool)
loc_run_start[0] = True
np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
run_starts = np.nonzero(loc_run_start)[0]

# find run values
run_values = x[loc_run_start]
# find run lengths
run_lengths = np.diff(np.append(run_starts, n))
return run_values, run_starts, run_lengths




def pad_sites(ts):
"""
Fill in missing sites with the reference state.
Expand Down

0 comments on commit 7b9b5c4

Please sign in to comment.