Skip to content

Commit

Permalink
Merge pull request #299 from jeromekelleher/stricter-filtering
Browse files Browse the repository at this point in the history
Stricter filtering
  • Loading branch information
jeromekelleher authored Sep 24, 2024
2 parents 745fb67 + f0989c9 commit 99048a5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ alignments=$datadir/alignments.db
metadata=$datadir/metadata.db
matches=$resultsdir/matches.db

options="--num-threads $num_threads -vv " #-l $logfile "
options="--num-threads $num_threads -vv -l $logfile "
# options+="--max-submission-delay $max_submission_delay "
# options+="--max-daily-samples $max_daily_samples "
options+="--num-mismatches $mismatches"
Expand Down
2 changes: 2 additions & 0 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
NODE_IS_RECOMBINANT = 1 << 23
NODE_IS_EXACT_MATCH = 1 << 24
NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25
NODE_IN_SAMPLE_GROUP = 1 << 26
NODE_IN_RETROSPECTIVE_SAMPLE_GROUP = 1 << 27


__version__ = "undefined"
Expand Down
47 changes: 35 additions & 12 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,9 @@ def extend(
match_db=match_db,
date=date,
min_group_size=1,
additional_node_flags=core.NODE_IN_SAMPLE_GROUP,
show_progress=show_progress,
phase="add(close)",
phase="close",
)

logger.info("Looking for retrospective matches")
Expand All @@ -614,10 +615,12 @@ def extend(
match_db=match_db,
date=date,
min_group_size=min_group_size,
min_different_dates=3, # TODO parametrize
additional_group_metadata_keys=["Country"],
min_different_dates=3, # TODO parametrise
additional_group_metadata_keys=["Country"], # TODO parametrise
min_root_mutations=3, # TODO parametrise
additional_node_flags=core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP,
show_progress=show_progress,
phase="add(retro)",
phase="retro",
)
return update_top_level_metadata(ts, date)

Expand Down Expand Up @@ -747,7 +750,7 @@ def __iter__(self):
def summary(self):
return (
f"Group {self.sample_hash} {len(self.samples)} samples "
f"({dict(self.date_count)}) "
f"{dict(self.date_count)} "
f"attaching at {path_summary(self.path)}, "
f"immediate_reversions={self.immediate_reversions}, "
f"additional_keys={self.additional_keys};"
Expand All @@ -762,6 +765,8 @@ def add_matching_results(
date,
min_group_size=1,
min_different_dates=1,
min_root_mutations=0,
additional_node_flags=None,
show_progress=False,
additional_group_metadata_keys=list(),
phase=None,
Expand Down Expand Up @@ -804,13 +809,16 @@ def add_matching_results(
tables = ts.dump_tables()

attach_nodes = []
with get_progress(groups, date, phase, show_progress) as bar:
with get_progress(groups, date, f"add({phase})", show_progress) as bar:
for group in bar:
if (
len(group) < min_group_size
or len(group.date_count) < min_different_dates
):
logger.debug(f"Skipping {group.summary()}")
logger.debug(
f"Skipping size={len(group)} dates={len(group.date_count)}: "
f"{group.summary()}"
)
continue

for sample in group:
Expand All @@ -824,17 +832,27 @@ def add_matching_results(
# print(binary_ts.draw_text())
# print(binary_ts.tables.mutations)
poly_ts = trim_branches(binary_ts)

# print(poly_ts.draw_text())
# print(poly_ts.tables.mutations)
# print("----")
assert poly_ts.num_samples == flat_ts.num_samples
tree = poly_ts.first()
num_root_mutations = np.sum(poly_ts.mutations_node == tree.root)
num_recurrent_mutations = np.sum(poly_ts.mutations_parent != -1)
if num_root_mutations < min_root_mutations:
logger.debug(
f"Skipping root_mutations={num_root_mutations}: "
f"{group.summary()}"
)
continue
attach_depth = max(tree.depth(u) for u in poly_ts.samples())
nodes = attach_tree(ts, tables, group, poly_ts, date)
nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags)
logger.debug(
f"{group.summary()}; "
f"depth={attach_depth} mutations={poly_ts.num_mutations} "
f"attach_nodes={nodes}"
f"Attach {phase} {group.summary()}; "
f"depth={attach_depth} total_muts{poly_ts.num_mutations} "
f"root_muts={num_root_mutations} "
f"recurrent_muts={num_recurrent_mutations} attach_nodes={nodes}"
)
attach_nodes.extend(nodes)

Expand Down Expand Up @@ -1683,6 +1701,7 @@ def attach_tree(
group,
child_ts,
date,
additional_node_flags,
epsilon=None,
):
attach_path = group.path
Expand Down Expand Up @@ -1741,7 +1760,11 @@ def attach_tree(
"date_added": date,
}
}
new_id = parent_tables.nodes.append(node.replace(time=time, metadata=metadata))
new_id = parent_tables.nodes.append(
node.replace(
flags=node.flags | additional_node_flags, time=time, metadata=metadata
)
)
node_id_map[node.id] = new_id
for v in tree.children(u):
parent_tables.edges.add_row(
Expand Down
6 changes: 6 additions & 0 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def node_counts(self):
pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH)
re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT)
exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0)
sg_nodes = np.sum((self.ts.nodes_flags & core.NODE_IN_SAMPLE_GROUP) > 0)
rsg_nodes = np.sum(
(self.ts.nodes_flags & core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP) > 0
)
immediate_reversion_marker = np.sum(
(self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0
)
Expand All @@ -336,6 +340,8 @@ def node_counts(self):
"mc": mc_nodes,
"pr": pr_nodes,
"re": re_nodes,
"sg": sg_nodes,
"rsg": rsg_nodes,
"imr": immediate_reversion_marker,
"zero_muts": nodes_with_zero_muts,
}
Expand Down
1 change: 1 addition & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def test_group(self, fx_ts_map, gid, date, internal, strains):
md = node.metadata
group = md["sc2ts"].get("group_id", None)
if group == gid:
assert node.flags & sc2ts.NODE_IN_SAMPLE_GROUP > 0
if node.is_sample():
got_strains.append(md["strain"])
assert md["date"] == date
Expand Down

0 comments on commit 99048a5

Please sign in to comment.