diff --git a/core/src/model/order.rs b/core/src/model/order.rs index 1394b9e592..ab8a236e55 100644 --- a/core/src/model/order.rs +++ b/core/src/model/order.rs @@ -17,7 +17,7 @@ where } /// Find a working evaluation order for a list of nodes. -pub fn old_eval_order_for_nodes( +pub fn eval_order_for_nodes( nodes: &[Node], model_inputs: &[usize], model_outputs: &[usize], @@ -82,7 +82,7 @@ where } /// Find a working evaluation order for a list of nodes. -pub fn eval_order_for_nodes( +pub fn eval_order_for_nodes_memory( nodes: &[Node], _model_inputs: &[usize], model_outputs: &[usize], @@ -93,16 +93,19 @@ where O: Debug + Display + AsRef + AsMut + Clone + 'static, { let mut ups = vec![tvec!(); nodes.len()]; + let mut downs = vec![tvec!(); nodes.len()]; for (ix, node) in nodes.iter().enumerate() { for input in &node.inputs { if !ups[ix].contains(&input.node) { ups[ix].push(input.node); + downs[input.node].push(ix); } } } for (down, up) in more_dependencies { - if !ups[*down].contains(&up) { + if !ups[*down].contains(up) { ups[*down].push(*up); + downs[*up].push(*down); } } let costs: Vec = nodes @@ -113,7 +116,13 @@ where .map(|o| { o.fact .to_typed_fact() - .map(|f| f.datum_type.size_of() * f.shape.volume().to_usize().unwrap_or(0)) + .map(|f| { + f.datum_type.size_of() + * f.shape + .as_concrete() + .map(|dims| dims.iter().product()) + .unwrap_or(0) + }) .unwrap_or(0) }) .sum() @@ -131,22 +140,33 @@ where } } let mut order = vec![]; + let mut active = BitSet::with_capacity(nodes.len()); + let mut candidates = BitSet::with_capacity(nodes.len()); + candidates.extend(todo.iter().filter(|n| ups[*n].len() == 0)); while todo.len() > 0 { - let next = todo + let next = candidates .iter() .filter(|n| !ups[*n].iter().any(|up| todo.contains(*up))) .min_by_key(|&candidate| { - todo.iter() - .filter(|it| *it != candidate) - .flat_map(|down| ups[down].iter().copied()) - .filter(|up| *up == candidate || !todo.contains(*up)) - .unique() - .map(|n| costs[n]) - .sum::() + active.clear(); + active.extend( + todo.iter() + .filter(|it| *it != candidate) + .flat_map(|down| ups[down].iter().copied()) + .filter(|up| *up == candidate || !todo.contains(*up)), + ); + active.iter().map(|n| costs[n]).sum::() }) .context("Dependency loop detected.")?; order.push(next); todo.remove(next); + candidates.remove(next); + candidates.extend( + downs[next] + .iter() + .copied() + .filter(|n| todo.contains(*n) && ups[*n].iter().all(|up| !todo.contains(*up))), + ); } Ok(order) } diff --git a/core/src/plan.rs b/core/src/plan.rs index 95a80ff800..00c048b1bd 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -3,11 +3,12 @@ use std::fmt::{Debug, Display}; use std::marker::PhantomData; use crate::internal::*; -use crate::model::order::eval_order_for_nodes; use crate::model::{Fact, Graph, OutletId}; use crate::ops::konst::Const; use crate::ops::FrozenOpState; +use self::order::eval_order_for_nodes_memory; + #[derive(Default)] pub struct SessionState { pub inputs: HashMap, @@ -78,7 +79,7 @@ where let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::>(); let outputs_nodes = outputs.iter().map(|n| n.node).collect::>(); let mut order = - eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?; + eval_order_for_nodes_memory(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?; order.retain(|node| !model.borrow().node(*node).op_is::()); let mut values_needed_until_step = vec![0; model.borrow().nodes().len()]; for (step, node) in order.iter().enumerate() {