Skip to content

Commit

Permalink
Merge pull request #1250 from nextstrain/refactor/tree-builder-finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-aksamentov authored Sep 12, 2023
2 parents 15757ec + 894409f commit 4e5feb7
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 89 deletions.
16 changes: 16 additions & 0 deletions packages_rs/nextclade/src/analyze/divergence.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::analyze::nuc_sub::NucSub;
use crate::coord::range::NucRefGlobalRange;
use crate::tree::params::TreeBuilderParams;
use crate::tree::tree::DivergenceUnits;

/// Calculate number of nuc muts, only considering ACGT characters
Expand All @@ -24,3 +26,17 @@ pub fn calculate_branch_length(

this_div
}

/// Calculate nuc mut score
pub fn score_nuc_muts(nuc_muts: &[NucSub], masked_ranges: &[NucRefGlobalRange], params: &TreeBuilderParams) -> f64 {
// Only consider ACGT characters
let nuc_muts = nuc_muts.iter().filter(|m| m.ref_nuc.is_acgt() && m.qry_nuc.is_acgt());

// Split away masked mutations
let (masked_muts, muts): (Vec<_>, Vec<_>) =
nuc_muts.partition(|m| masked_ranges.iter().any(|range| range.contains(m.pos)));

let n_muts = muts.len() as f64;
let n_masked_muts = masked_muts.len() as f64;
n_muts + n_masked_muts * params.masked_muts_weight
}
2 changes: 1 addition & 1 deletion packages_rs/nextclade/src/run/nextclade_wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl Nextclade {

pub fn get_output_trees(&mut self, results: Vec<NextcladeOutputs>) -> Result<Option<OutputTrees>, Report> {
if let Some(graph) = &mut self.graph {
graph_attach_new_nodes_in_place(graph, results, self.ref_seq.len(), &self.params.tree_builder)?;
graph_attach_new_nodes_in_place(graph, results, self.ref_seq.len(), &self.params.tree_builder)?;
let auspice = convert_graph_to_auspice_tree(graph)?;
let nwk = convert_graph_to_nwk_string(graph)?;
Ok(Some(OutputTrees { auspice, nwk }))
Expand Down
4 changes: 4 additions & 0 deletions packages_rs/nextclade/src/tree/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ pub struct TreeBuilderParams {
#[clap(long)]
#[clap(num_args=0..=1, default_missing_value = "true")]
pub without_greedy_tree_builder: bool,

#[clap(long)]
pub masked_muts_weight: f64,
}

#[allow(clippy::derivable_impls)]
impl Default for TreeBuilderParams {
fn default() -> Self {
Self {
without_greedy_tree_builder: false,
masked_muts_weight: 0.05,
}
}
}
4 changes: 2 additions & 2 deletions packages_rs/nextclade/src/tree/split_muts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ where
pub fn difference_of_muts(left: &BranchMutations, right: &BranchMutations) -> Result<BranchMutations, Report> {
Ok(BranchMutations {
nuc_muts: difference(&left.nuc_muts, &right.nuc_muts)
.wrap_err("When calculating union of private nucleotide substitutions")?,
.wrap_err("When calculating difference of private nucleotide substitutions")?,
aa_muts: difference_of_aa_muts(&left.aa_muts, &right.aa_muts)
.wrap_err("When calculating union of private aminoacid mutations")?,
.wrap_err("When calculating difference of private aminoacid mutations")?,
})
}

Expand Down
206 changes: 120 additions & 86 deletions packages_rs/nextclade/src/tree/tree_builder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::analyze::aa_del::AaDel;
use crate::analyze::aa_sub::AaSub;
use crate::analyze::divergence::{calculate_branch_length, count_nuc_muts};
use crate::analyze::divergence::{calculate_branch_length, count_nuc_muts, score_nuc_muts};
use crate::analyze::find_private_nuc_mutations::BranchMutations;
use crate::analyze::nuc_del::NucDel;
use crate::analyze::nuc_sub::NucSub;
use crate::graph::node::GraphNodeKey;
use crate::make_internal_report;
use crate::coord::range::NucRefGlobalRange;
use crate::graph::node::{GraphNodeKey, Node};
use crate::tree::params::TreeBuilderParams;
use crate::tree::split_muts::{difference_of_muts, split_muts, union_of_muts, SplitMutsResult};
use crate::tree::tree::{AuspiceGraph, AuspiceGraphEdgePayload, AuspiceGraphNodePayload, TreeBranchAttrsLabels};
Expand Down Expand Up @@ -87,7 +87,7 @@ pub fn graph_attach_new_node_in_place(
} else {
// for the attachment on the reference tree ('result') fine tune the position
// on the updated graph to minimize the number of private mutations
finetune_nearest_node(graph, result.nearest_node_id, &mutations_seq)?
finetune_nearest_node(graph, result.nearest_node_id, &mutations_seq, params)?
};

// add the new node at the fine tuned position while accounting for shared mutations
Expand All @@ -97,109 +97,143 @@ pub fn graph_attach_new_node_in_place(
Ok(())
}

/// Moves the new sequences, defined by its set of private mutations
/// along the tree starting at the `nearest_node`. As the new sequence is moved, the
/// private mutations are updated. This is repeated until the number of private mutations (nuc)
/// can not be reduced further by moving the node. At the end of the loop, the nearest node
/// is either the closest possible point, or this closest point is along the branch leading
/// to the nearest_node.
pub fn finetune_nearest_node(
graph: &AuspiceGraph,
nearest_node_key: GraphNodeKey,
seq_private_mutations: &BranchMutations,
params: &TreeBuilderParams,
) -> Result<(GraphNodeKey, BranchMutations), Report> {
let mut current_best_node = graph.get_node(nearest_node_key)?;
let masked_ranges = graph.data.meta.placement_mask_ranges();
let mut best_node = graph.get_node(nearest_node_key)?;
let mut private_mutations = seq_private_mutations.clone();

// the following loop moves the new sequences, defined by its set of private mutations
// along the tree starting at the `nearest_node`. As the new sequence is moved, the
// private mutations are updated. This is repeated until the number of private mutations (nuc)
// can not be reduced further by moving the node. At the end of the loop, the nearest node
// is either the closest possible point, or this closest point is along the branch leading
// to the nearest_node.
loop {
// in each iteration, check how many mutation are shared with the branch leading to
// the current_best_node or any of its children (loop further below).
let mut best_node = current_best_node;
let (mut best_split_result, mut n_shared_muts) = if current_best_node.is_root() {
// don't include node if node is root as we don't attach nodes above the root
let best_split_result = SplitMutsResult {
left: private_mutations.clone(),
right: BranchMutations::default(),
shared: BranchMutations::default(),
};
(best_split_result, 0)
} else {
let best_split_result = split_muts(
&current_best_node.payload().tmp.private_mutations.invert(),
&private_mutations,
)
.wrap_err_with(|| {
// Check how many mutations are shared with the branch leading to the current_best_node or any of its children
let (candidate_node, candidate_split, shared_muts_score) =
find_shared_muts(graph, best_node, &private_mutations, masked_ranges, params).wrap_err_with(|| {
format!(
"When splitting mutations between query sequence and the nearest node '{}'",
current_best_node.payload().name
"When calculating shared mutations against the current best node '{}'",
best_node.payload().name
)
})?;
let n_shared_muts = count_nuc_muts(&best_split_result.shared.nuc_muts);
(best_split_result, n_shared_muts)
};

// check all child nodes for shared mutations
for child in graph.iter_children_of(current_best_node) {
let tmp_split_result =
split_muts(&child.payload().tmp.private_mutations, &private_mutations).wrap_err_with(|| {
format!(
"When splitting mutations between query sequence and the child node '{}'",
child.payload().name
)
})?;
let tmp_n_shared_muts = count_nuc_muts(&tmp_split_result.shared.nuc_muts);
if tmp_n_shared_muts > n_shared_muts {
n_shared_muts = tmp_n_shared_muts;
best_split_result = tmp_split_result;
best_node = child;
}
// Check if the new candidate node is better than the current best
let left_muts_score = score_nuc_muts(&candidate_split.left.nuc_muts, masked_ranges, params);
match find_better_node_maybe(graph, best_node, candidate_node, shared_muts_score, left_muts_score) {
None => break,
Some(better_node) => best_node = better_node,
}

// if shared mutations are found, the current_best_node is updated
if n_shared_muts > 0 {
if best_node.key() == current_best_node.key() && best_split_result.left.nuc_muts.is_empty() {
// All mutations from the parent to the node are shared with private mutations. Move up to the parent.
// FIXME: what if there's no parent?
current_best_node = graph
.parent_of_by_key(best_node.key())
.ok_or_else(|| make_internal_report!("Parent node is expected, but not found"))?;
} else if best_node.key() == current_best_node.key() {
// The best node is the current node. Break.
break;
} else {
// The best node is child
current_best_node = graph.get_node(best_node.key())?;
}
} else if current_best_node.is_leaf()
&& !current_best_node.is_root()
&& current_best_node.payload().tmp.private_mutations.nuc_muts.is_empty()
{
current_best_node = graph
.parent_of_by_key(best_node.key())
.ok_or_else(|| make_internal_report!("Parent node is expected, but not found"))?;
} else {
break;
}
// update the private mutations to match the new 'current_best_node'. This involves
// in step 1 subtracting the shared mutations from the private mutations struct
private_mutations = difference_of_muts(&private_mutations, &best_split_result.shared).wrap_err_with(|| {
// Update query mutations to adjust for the new position of the placed node
private_mutations = update_private_mutations(&private_mutations, &candidate_split).wrap_err_with(|| {
format!(
"When calculating difference of mutations between query sequence and the branch leading to the next attachment point '{}'",
current_best_node.payload().name
"When updating private mutations against the current best node '{}'",
best_node.payload().name
)
})?;
// in step 2 we need to add the inverted remaining mutations on that branch.
// Not that this can be necessary even if there are no left-over nuc_subs.
// Amino acid mutations can be decoupled from the their nucleotide mutations or
// changes in the amino acid sequences due to mutations in the same codon still need handling
private_mutations = union_of_muts(&private_mutations, &best_split_result.left.invert()).wrap_err_with(|| {
}

Ok((best_node.key(), private_mutations))
}

/// Check how many mutations are shared with the branch leading to the current_best_node or any of its children
fn find_shared_muts<'g>(
graph: &'g AuspiceGraph,
best_node: &'g Node<AuspiceGraphNodePayload>,
private_mutations: &BranchMutations,
masked_ranges: &[NucRefGlobalRange],
params: &TreeBuilderParams,
) -> Result<(&'g Node<AuspiceGraphNodePayload>, SplitMutsResult, f64), Report> {
let (mut candidate_split, mut shared_muts_score) = if best_node.is_root() {
// Don't include node if node is root as we don't attach nodes above the root
let candidate_split = SplitMutsResult {
left: BranchMutations::default(),
right: private_mutations.clone(),
shared: BranchMutations::default(),
};
(candidate_split, 0.0)
} else {
let candidate_split = split_muts(&best_node.payload().tmp.private_mutations.invert(), private_mutations)
.wrap_err_with(|| {
format!(
"When splitting mutations between query sequence and the nearest node '{}'",
best_node.payload().name
)
})?;
let shared_muts_score = score_nuc_muts(&candidate_split.shared.nuc_muts, masked_ranges, params);
(candidate_split, shared_muts_score)
};

// Check all child nodes for shared mutations
let mut candidate_node = best_node;
for child in graph.iter_children_of(best_node) {
let child_split = split_muts(&child.payload().tmp.private_mutations, private_mutations).wrap_err_with(|| {
format!(
"When calculating union of mutations between query sequence and the branch leading to the next attachment point '{}'",
best_node.payload().name
"When splitting mutations between query sequence and the child node '{}'",
child.payload().name
)
})?;
let child_shared_muts_score = score_nuc_muts(&child_split.shared.nuc_muts, masked_ranges, params);
if child_shared_muts_score > shared_muts_score {
shared_muts_score = child_shared_muts_score;
candidate_split = child_split;
candidate_node = child;
}
}
Ok((current_best_node.key(), private_mutations))
Ok((candidate_node, candidate_split, shared_muts_score))
}

/// Find out if the candidate node is better than the current best (with caveats).
/// Return a better node or `None` (if the current best node is to be preserved).
fn find_better_node_maybe<'g>(
graph: &'g AuspiceGraph,
best_node: &'g Node<AuspiceGraphNodePayload>,
candidate_node: &'g Node<AuspiceGraphNodePayload>,
shared_muts_score: f64,
left_muts_score: f64,
) -> Option<&'g Node<AuspiceGraphNodePayload>> {
if candidate_node == best_node {
// best node is the node itself. Move up the tree if all mutations between
// the candidate node and its parent are also in the private mutations.
// This covers the case where the candidate is a leaf with zero length branch
// as the .left.nuc_muts is empty in that case
if left_muts_score == 0.0 {
return graph.parent_of(candidate_node);
}
} else if shared_muts_score > 0.0 {
// candidate node is child node, move to child node if there are shared mutations
// this should always be the case if the candidate node != best_node
return Some(candidate_node);
}
// no improvement possible. Return None to stay
None
}

/// Update private mutations to match the new best node
fn update_private_mutations(
private_mutations: &BranchMutations,
best_split_result: &SplitMutsResult,
) -> Result<BranchMutations, Report> {
// Step 1: subtract shared mutations from private mutations
let private_mutations = difference_of_muts(private_mutations, &best_split_result.shared).wrap_err(
"When calculating difference of mutations between query sequence and the branch leading to the next attachment point"
)?;

// Step 2: We need to add the inverted remaining mutations on that branch.
// Note that this can be necessary even if there are no left-over nuc_subs.
// Amino acid mutations can be decoupled from the their nucleotide mutations or
// changes in the amino acid sequences due to mutations in the same codon still need handling.
let private_mutations = union_of_muts(&private_mutations, &best_split_result.left.invert()).wrap_err(
"When calculating union of mutations between query sequence and the branch leading to the next attachment point.",
)?;

Ok(private_mutations)
}

pub fn attach_to_internal_node(
Expand Down

0 comments on commit 4e5feb7

Please sign in to comment.