Skip to content

Commit

Permalink
🎵 up, down, turn aroud
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieupoumeyrolsonos authored and kali committed Apr 11, 2024
1 parent 1a398db commit 18217f6
Showing 1 changed file with 76 additions and 2 deletions.
78 changes: 76 additions & 2 deletions core/src/model/order.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Evaluation order for nodes.
use crate::internal::*;
use bit_set;
use bit_set::{self, BitSet};
use std::fmt::{Debug, Display};
use tract_data::itertools::Itertools;

/// Find an evaluation order for a model, using its default inputs and outputs
/// as boundaries.
Expand All @@ -16,7 +17,7 @@ where
}

/// Find a working evaluation order for a list of nodes.
pub fn eval_order_for_nodes<F, O>(
pub fn old_eval_order_for_nodes<F, O>(
nodes: &[Node<F, O>],
model_inputs: &[usize],
model_outputs: &[usize],
Expand Down Expand Up @@ -80,6 +81,79 @@ where
Ok(order)
}

/// Find a working evaluation order for a list of nodes.
pub fn eval_order_for_nodes<F, O>(
nodes: &[Node<F, O>],
_model_inputs: &[usize],
model_outputs: &[usize],
more_dependencies: &[(usize, usize)],
) -> TractResult<Vec<usize>>
where
F: Fact + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
let mut ups = vec![tvec!(); nodes.len()];
for (ix, node) in nodes.iter().enumerate() {
for input in &node.inputs {
if !ups[ix].contains(&input.node) {
ups[ix].push(input.node);
}
}
}
for (down, up) in more_dependencies {
if !ups[*down].contains(&up) {
ups[*down].push(*up);
}
}
let costs: Vec<usize> = nodes
.iter()
.map(|node| {
node.outputs
.iter()
.map(|o| {
o.fact
.to_typed_fact()
.map(|f| f.datum_type.size_of() * f.shape.volume().to_usize().unwrap_or(0))
.unwrap_or(0)
})
.sum()
})
.collect_vec();
let mut todo = bit_set::BitSet::with_capacity(nodes.len());
todo.extend(model_outputs.iter().copied());
loop {
let mut up: BitSet = todo.iter().flat_map(|n| ups[n].iter().copied()).collect::<BitSet>();
up.difference_with(&todo);
if up.len() == 0 {
break;
} else {
todo.union_with(&up);
}
}
let mut order = vec![];
while todo.len() > 0 {
let next = todo
.iter()
.filter(|n| !ups[*n].iter().any(|up| todo.contains(*up)))
.min_by_key(|&candidate| {
let mut state = todo.clone();
state.remove(candidate);
state
.iter()
.flat_map(|down| ups[down].iter().copied())
.filter(|up| !state.contains(*up))
.sorted()
.dedup()
.map(|n| costs[n])
.sum::<usize>()
})
.context("Dependency loop detected.")?;
order.push(next);
todo.remove(next);
}
Ok(order)
}

#[cfg(test)]
mod tests {
use crate::internal::*;
Expand Down

0 comments on commit 18217f6

Please sign in to comment.