Skip to content

Commit

Permalink
keep the simple variant except for plan
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieupoumeyrolsonos authored and kali committed Apr 11, 2024
1 parent 4a83533 commit dac1276
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
44 changes: 32 additions & 12 deletions core/src/model/order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ where
}

/// Find a working evaluation order for a list of nodes.
pub fn old_eval_order_for_nodes<F, O>(
pub fn eval_order_for_nodes<F, O>(
nodes: &[Node<F, O>],
model_inputs: &[usize],
model_outputs: &[usize],
Expand Down Expand Up @@ -82,7 +82,7 @@ where
}

/// Find a working evaluation order for a list of nodes.
pub fn eval_order_for_nodes<F, O>(
pub fn eval_order_for_nodes_memory<F, O>(
nodes: &[Node<F, O>],
_model_inputs: &[usize],
model_outputs: &[usize],
Expand All @@ -93,16 +93,19 @@ where
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
let mut ups = vec![tvec!(); nodes.len()];
let mut downs = vec![tvec!(); nodes.len()];
for (ix, node) in nodes.iter().enumerate() {
for input in &node.inputs {
if !ups[ix].contains(&input.node) {
ups[ix].push(input.node);
downs[input.node].push(ix);
}
}
}
for (down, up) in more_dependencies {
if !ups[*down].contains(&up) {
if !ups[*down].contains(up) {
ups[*down].push(*up);
downs[*up].push(*down);
}
}
let costs: Vec<usize> = nodes
Expand All @@ -113,7 +116,13 @@ where
.map(|o| {
o.fact
.to_typed_fact()
.map(|f| f.datum_type.size_of() * f.shape.volume().to_usize().unwrap_or(0))
.map(|f| {
f.datum_type.size_of()
* f.shape
.as_concrete()
.map(|dims| dims.iter().product())
.unwrap_or(0)
})
.unwrap_or(0)
})
.sum()
Expand All @@ -131,22 +140,33 @@ where
}
}
let mut order = vec![];
let mut active = BitSet::with_capacity(nodes.len());
let mut candidates = BitSet::with_capacity(nodes.len());
candidates.extend(todo.iter().filter(|n| ups[*n].len() == 0));
while todo.len() > 0 {
let next = todo
let next = candidates
.iter()
.filter(|n| !ups[*n].iter().any(|up| todo.contains(*up)))
.min_by_key(|&candidate| {
todo.iter()
.filter(|it| *it != candidate)
.flat_map(|down| ups[down].iter().copied())
.filter(|up| *up == candidate || !todo.contains(*up))
.unique()
.map(|n| costs[n])
.sum::<usize>()
active.clear();
active.extend(
todo.iter()
.filter(|it| *it != candidate)
.flat_map(|down| ups[down].iter().copied())
.filter(|up| *up == candidate || !todo.contains(*up)),
);
active.iter().map(|n| costs[n]).sum::<usize>()
})
.context("Dependency loop detected.")?;
order.push(next);
todo.remove(next);
candidates.remove(next);
candidates.extend(
downs[next]
.iter()
.copied()
.filter(|n| todo.contains(*n) && ups[*n].iter().all(|up| !todo.contains(*up))),
);
}
Ok(order)
}
Expand Down
5 changes: 3 additions & 2 deletions core/src/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::fmt::{Debug, Display};
use std::marker::PhantomData;

use crate::internal::*;
use crate::model::order::eval_order_for_nodes;
use crate::model::{Fact, Graph, OutletId};
use crate::ops::konst::Const;
use crate::ops::FrozenOpState;

use self::order::eval_order_for_nodes_memory;

#[derive(Default)]
pub struct SessionState {
pub inputs: HashMap<usize, TValue>,
Expand Down Expand Up @@ -78,7 +79,7 @@ where
let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
let mut order =
eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
eval_order_for_nodes_memory(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
order.retain(|node| !model.borrow().node(*node).op_is::<Const>());
let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
for (step, node) in order.iter().enumerate() {
Expand Down

0 comments on commit dac1276

Please sign in to comment.