diff --git a/src/graph.rs b/src/graph.rs index 4780bfcf..99486abe 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -711,7 +711,11 @@ impl Graph { captures } - pub fn add_node(&mut self, node: Node) -> NodeId { + /// Add a node to the graph and return its ID. + /// + /// This contains the common logic for adding different types of node to + /// the graph. + fn add_node(&mut self, node: Node) -> NodeId { let node_id = NodeId::from_u32(self.nodes.len() as u32); self.nodes.push(node); @@ -722,6 +726,13 @@ impl Graph { node_id } + /// Invalidate cached execution plans. + fn clear_cached_plan(&mut self) { + if let Ok(plan) = self.cached_plan.get_mut() { + plan.take(); + } + } + /// Add an operator node to the graph. /// /// `name` is an identifier for this node that is used in debug messages etc. @@ -732,7 +743,10 @@ impl Graph { /// operators. /// /// `outputs` specifies which value nodes the operator's outputs should be - /// written to. + /// written to. If there is already an existing operator which uses the + /// same output, the new operator will become the source for this output + /// value. This enables replacing an operator while preserving metadata + /// of the output value (name, shape etc.). /// /// Returns the ID of the operator node. pub fn add_op( @@ -753,6 +767,10 @@ impl Graph { self.source_ids.insert(*output_id, op_id); } + // Clear cached plan in case we just replaced the source operator for + // one of the output IDs. + self.clear_cached_plan(); + op_id } @@ -776,7 +794,7 @@ impl Graph { (op_node_id, op_out_id) } - /// Add a constant node to the graph. + /// Convert `value` to a constant node and add it to the graph. /// /// `name` is an identifier for this node that is used in debug messages etc. /// @@ -786,12 +804,20 @@ impl Graph { V: Into>, ConstantNode: Into, { - let node = ConstantNode { - name: name.map(|s| s.to_owned()), - data: value.into(), - }; + self.add_constant_node( + ConstantNode { + name: name.map(|s| s.to_owned()), + data: value.into(), + } + .into(), + ) + } - self.add_node(Node::Constant(node.into())) + /// Add a constant node to the graph. + /// + /// Returns the ID of the added node. + pub fn add_constant_node(&mut self, node: Constant) -> NodeId { + self.add_node(Node::Constant(node)) } /// Add a value node to the graph. diff --git a/src/optimize.rs b/src/optimize.rs index 1776b54b..fd035abb 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -79,16 +78,25 @@ impl GraphMutator { self.graph.add_constant(name, value) } + /// Add a new constant value to the graph. + fn add_constant_node(&mut self, const_node: Constant) -> NodeId { + self.graph.add_constant_node(const_node) + } + /// Add a new operator to the graph with a single output node. /// + /// `op_output_id` specifies the ID of the output node. If not specified, + /// a new value node is created. + /// /// Returns the ID of the output node. fn add_operator( &mut self, name: Option<&str>, op: Box, inputs: &[Option], + op_output_id: Option, ) -> NodeId { - let op_output_id = self.graph.add_value(None, None); + let op_output_id = op_output_id.unwrap_or(self.graph.add_value(None, None)); let op_id = self.graph.add_op(name, op, inputs, &[Some(op_output_id)]); for input_id in inputs.iter().filter_map(|id| *id) { @@ -102,11 +110,6 @@ impl GraphMutator { op_output_id } - /// Add a new node to the graph. - fn add_node(&mut self, node: Node) -> NodeId { - self.graph.add_node(node) - } - /// Return a reference to the graph. /// /// Note there is no mutable variant of this method. All graph updates must @@ -135,30 +138,19 @@ impl GraphMutator { &mut self, create_fusion: F, ) { - let mut fusions = Vec::new(); - - for (op_node_id, op_node) in self.iter_operators() { - if let Some(fusion) = create_fusion(self, op_node_id, op_node) { - fusions.push(fusion); - } - } - - // Map of old_output_id => new_output_id for subgraphs that have been - // replaced by fusions. - let mut replaced_ids = HashMap::::new(); - - for mut fusion in fusions { - // Replace input IDs which match output IDs of previously applied - // fusions. - for input_id in fusion.input_ids.iter_mut().flatten() { - if let Some(replacement_id) = replaced_ids.get(input_id) { - *input_id = *replacement_id; - } - } - - let (old_output_id, new_output_id) = fusion.apply(self); + let fusions: Vec<_> = self + .iter_operators() + .filter_map(|(op_node_id, op_node)| create_fusion(self, op_node_id, op_node)) + .collect(); - replaced_ids.insert(old_output_id, new_output_id); + for Fusion { + name, + fused_op, + input_ids, + output_id, + } in fusions + { + self.add_operator(name.as_deref(), fused_op, &input_ids, Some(output_id)); } } @@ -219,46 +211,27 @@ struct Fusion { name: Option, fused_op: Box, input_ids: Vec>, - old_output_id: NodeId, + output_id: NodeId, } impl Fusion { /// Create a fusion with a given operator, name and input nodes. /// - /// `old_output_id` specifies the output ID of the subgraph that this fusion + /// `output_id` specifies the output ID of the subgraph that this fusion /// replaces. fn from_op( name: Option<&str>, op: Op, input_ids: Vec>, - old_output_id: NodeId, + output_id: NodeId, ) -> Fusion { Fusion { name: name.map(|s| s.to_string()), fused_op: Box::new(op), input_ids, - old_output_id, + output_id, } } - - /// Apply the fusion to the graph. - /// - /// This adds the fused operator to the graph and replaces references to - /// the original output nodes with the fused operator's outputs. - /// - /// Returns a tuple of `(old_output_id, new_output_id)`. - fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) { - let Fusion { - name, - fused_op, - input_ids, - old_output_id, - } = self; - - let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids); - graph.replace_value(old_output_id, fused_op_output_id); - (old_output_id, fused_op_output_id) - } } /// Utilities for matching patterns in a graph. @@ -304,12 +277,11 @@ impl GraphOptimizer { /// Apply optimizations to a graph. /// - /// The input and output nodes specified by `input_ids` and `output_ids` - /// will be preserved, but their IDs may change. Other nodes in the graph - /// may be modified, removed or replaced by optimization. + /// The graph's input and output nodes, identified by + /// [`input_ids`](Graph::input_ids) and [`output_ids`](Graph::output_ids) + /// will be preserved. Other nodes may be modified, removed or replaced. /// - /// This method returns the new graph along with the node IDs in the new - /// graph that correspond to `input_ids` and `output_ids`. + /// Returns the optimized graph. pub fn optimize( &self, graph: Graph, @@ -358,7 +330,7 @@ impl GraphOptimizer { let mut new_captures: FxHashSet<_> = graph.graph().captures().iter().copied().collect(); for (capture_id, local_const) in captured_constants { - let const_id = graph.add_node(Node::Constant(local_const)); + let const_id = graph.add_constant_node(local_const); new_captures.remove(&capture_id); graph.replace_value(capture_id, const_id); } @@ -867,6 +839,37 @@ mod tests { assert_eq!(layer_norm.epsilon, Some(1e-6)); } + #[test] + fn test_optimize_preserves_input_output_nodes() { + let mut graph = Graph::new(); + + let input_1 = graph.add_value(None, None); + let input_2 = graph.add_value(None, None); + + // Add fuse-able Transpose + MatMul + let (_, transpose_out) = + graph.add_simple_op("transpose", Transpose { perm: None }, &[input_1]); + let (_, matmul_out) = graph.add_simple_op("matmul", MatMul {}, &[transpose_out, input_2]); + graph.set_input_ids(&[input_1, input_2]); + graph.set_output_ids(&[matmul_out]); + + let graph = optimize_graph(graph).unwrap(); + + // Verify that optimizer did change the graph + let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap(); + assert_eq!(op.operator().name(), "FusedTranspose(MatMul)"); + + // The IDs of the input and output nodes should be the same after + // optimization. + // + // The optimizer could have created new output nodes instead, but it + // would need to ensure that the new outputs preserved value node + // metadata (name, shape) from the original outputs. + assert_eq!(graph.input_ids(), &[input_1, input_2]); + assert_eq!(graph.output_ids(), &[matmul_out]); + assert_eq!(graph.node_name(matmul_out), "matmul_out"); + } + #[test] fn test_optimize_error() { let mut graph = Graph::new();