diff --git a/run.sh b/run.sh index 36e99f8..72d2c6b 100755 --- a/run.sh +++ b/run.sh @@ -20,7 +20,7 @@ alignments=$datadir/alignments.db metadata=$datadir/metadata.db matches=$resultsdir/matches.db -options="--num-threads $num_threads -vv " #-l $logfile " +options="--num-threads $num_threads -vv -l $logfile " # options+="--max-submission-delay $max_submission_delay " # options+="--max-daily-samples $max_daily_samples " options+="--num-mismatches $mismatches" diff --git a/sc2ts/core.py b/sc2ts/core.py index 494d67c..c7c424d 100644 --- a/sc2ts/core.py +++ b/sc2ts/core.py @@ -19,6 +19,8 @@ NODE_IS_RECOMBINANT = 1 << 23 NODE_IS_EXACT_MATCH = 1 << 24 NODE_IS_IMMEDIATE_REVERSION_MARKER = 1 << 25 +NODE_IN_SAMPLE_GROUP = 1 << 26 +NODE_IN_RETROSPECTIVE_SAMPLE_GROUP = 1 << 27 __version__ = "undefined" diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 1f6c7b4..63c75cb 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -601,8 +601,9 @@ def extend( match_db=match_db, date=date, min_group_size=1, + additional_node_flags=core.NODE_IN_SAMPLE_GROUP, show_progress=show_progress, - phase="add(close)", + phase="close", ) logger.info("Looking for retrospective matches") @@ -614,10 +615,12 @@ def extend( match_db=match_db, date=date, min_group_size=min_group_size, - min_different_dates=3, # TODO parametrize - additional_group_metadata_keys=["Country"], + min_different_dates=3, # TODO parametrise + additional_group_metadata_keys=["Country"], # TODO parametrise + min_root_mutations=3, # TODO parametrise + additional_node_flags=core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP, show_progress=show_progress, - phase="add(retro)", + phase="retro", ) return update_top_level_metadata(ts, date) @@ -747,7 +750,7 @@ def __iter__(self): def summary(self): return ( f"Group {self.sample_hash} {len(self.samples)} samples " - f"({dict(self.date_count)}) " + f"{dict(self.date_count)} " f"attaching at {path_summary(self.path)}, " f"immediate_reversions={self.immediate_reversions}, " f"additional_keys={self.additional_keys};" @@ -762,6 +765,8 @@ def add_matching_results( date, min_group_size=1, min_different_dates=1, + min_root_mutations=0, + additional_node_flags=None, show_progress=False, additional_group_metadata_keys=list(), phase=None, @@ -804,13 +809,16 @@ def add_matching_results( tables = ts.dump_tables() attach_nodes = [] - with get_progress(groups, date, phase, show_progress) as bar: + with get_progress(groups, date, f"add({phase})", show_progress) as bar: for group in bar: if ( len(group) < min_group_size or len(group.date_count) < min_different_dates ): - logger.debug(f"Skipping {group.summary()}") + logger.debug( + f"Skipping size={len(group)} dates={len(group.date_count)}: " + f"{group.summary()}" + ) continue for sample in group: @@ -824,17 +832,27 @@ def add_matching_results( # print(binary_ts.draw_text()) # print(binary_ts.tables.mutations) poly_ts = trim_branches(binary_ts) + # print(poly_ts.draw_text()) # print(poly_ts.tables.mutations) # print("----") assert poly_ts.num_samples == flat_ts.num_samples tree = poly_ts.first() + num_root_mutations = np.sum(poly_ts.mutations_node == tree.root) + num_recurrent_mutations = np.sum(poly_ts.mutations_parent != -1) + if num_root_mutations < min_root_mutations: + logger.debug( + f"Skipping root_mutations={num_root_mutations}: " + f"{group.summary()}" + ) + continue attach_depth = max(tree.depth(u) for u in poly_ts.samples()) - nodes = attach_tree(ts, tables, group, poly_ts, date) + nodes = attach_tree(ts, tables, group, poly_ts, date, additional_node_flags) logger.debug( - f"{group.summary()}; " - f"depth={attach_depth} mutations={poly_ts.num_mutations} " - f"attach_nodes={nodes}" + f"Attach {phase} {group.summary()}; " + f"depth={attach_depth} total_muts{poly_ts.num_mutations} " + f"root_muts={num_root_mutations} " + f"recurrent_muts={num_recurrent_mutations} attach_nodes={nodes}" ) attach_nodes.extend(nodes) @@ -1683,6 +1701,7 @@ def attach_tree( group, child_ts, date, + additional_node_flags, epsilon=None, ): attach_path = group.path @@ -1741,7 +1760,11 @@ def attach_tree( "date_added": date, } } - new_id = parent_tables.nodes.append(node.replace(time=time, metadata=metadata)) + new_id = parent_tables.nodes.append( + node.replace( + flags=node.flags | additional_node_flags, time=time, metadata=metadata + ) + ) node_id_map[node.id] = new_id for v in tree.children(u): parent_tables.edges.add_row( diff --git a/sc2ts/info.py b/sc2ts/info.py index c249324..0f301c3 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -325,6 +325,10 @@ def node_counts(self): pr_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_REVERSION_PUSH) re_nodes = np.sum(self.ts.nodes_flags == core.NODE_IS_RECOMBINANT) exact_matches = np.sum((self.ts.nodes_flags & core.NODE_IS_EXACT_MATCH) > 0) + sg_nodes = np.sum((self.ts.nodes_flags & core.NODE_IN_SAMPLE_GROUP) > 0) + rsg_nodes = np.sum( + (self.ts.nodes_flags & core.NODE_IN_RETROSPECTIVE_SAMPLE_GROUP) > 0 + ) immediate_reversion_marker = np.sum( (self.ts.nodes_flags & core.NODE_IS_IMMEDIATE_REVERSION_MARKER) > 0 ) @@ -336,6 +340,8 @@ def node_counts(self): "mc": mc_nodes, "pr": pr_nodes, "re": re_nodes, + "sg": sg_nodes, + "rsg": rsg_nodes, "imr": immediate_reversion_marker, "zero_muts": nodes_with_zero_muts, } diff --git a/tests/test_inference.py b/tests/test_inference.py index 6559356..9f81db5 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -447,6 +447,7 @@ def test_group(self, fx_ts_map, gid, date, internal, strains): md = node.metadata group = md["sc2ts"].get("group_id", None) if group == gid: + assert node.flags & sc2ts.NODE_IN_SAMPLE_GROUP > 0 if node.is_sample(): got_strains.append(md["strain"]) assert md["date"] == date