From 809ba69dbecfd7c7aea56cc1e46c1f6678023659 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sat, 9 Nov 2024 05:34:46 +0000 Subject: [PATCH 1/3] Reuse existing output node when replacing operator with a fusion When replacing a subgraph with a fused operator, reuse the output value node from the subgraph instead of creating a new one. This preserves metadata such as the name and shape associated with that value node. Also it simplifies the code by removing the need to replace all references to the previous output node with the new one. This improves the runtime of the Whisper example using the whisper-base model by ~25% by fixing an issue where fused Transpose + MatMul operations did not get used. --- src/graph.rs | 16 +++++++- src/optimize.rs | 100 +++++++++++++++++++++++++----------------------- 2 files changed, 67 insertions(+), 49 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index 4780bfcf..cb83b64e 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -722,6 +722,13 @@ impl Graph { node_id } + /// Invalidate cached execution plans. + fn clear_cached_plan(&mut self) { + if let Ok(plan) = self.cached_plan.get_mut() { + plan.take(); + } + } + /// Add an operator node to the graph. /// /// `name` is an identifier for this node that is used in debug messages etc. @@ -732,7 +739,10 @@ impl Graph { /// operators. /// /// `outputs` specifies which value nodes the operator's outputs should be - /// written to. + /// written to. If there is already an existing operator which uses the + /// same output, the new operator will become the source for this output + /// value. This enables replacing an operator while preserving metadata + /// of the output value (name, shape etc.). /// /// Returns the ID of the operator node. pub fn add_op( @@ -753,6 +763,10 @@ impl Graph { self.source_ids.insert(*output_id, op_id); } + // Clear cached plan in case we just replaced the source operator for + // one of the output IDs. 
+ self.clear_cached_plan(); + op_id } diff --git a/src/optimize.rs b/src/optimize.rs index 1776b54b..e117af5c 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -81,14 +80,18 @@ impl GraphMutator { /// Add a new operator to the graph with a single output node. /// + /// `op_output_id` specifies the ID of the output node. If not specified, + /// a new value node is created. + /// /// Returns the ID of the output node. fn add_operator( &mut self, name: Option<&str>, op: Box, inputs: &[Option], + op_output_id: Option, ) -> NodeId { - let op_output_id = self.graph.add_value(None, None); + let op_output_id = op_output_id.unwrap_or_else(|| self.graph.add_value(None, None)); let op_id = self.graph.add_op(name, op, inputs, &[Some(op_output_id)]); for input_id in inputs.iter().filter_map(|id| *id) { @@ -135,30 +138,19 @@ impl GraphMutator { &mut self, create_fusion: F, ) { - let mut fusions = Vec::new(); - - for (op_node_id, op_node) in self.iter_operators() { - if let Some(fusion) = create_fusion(self, op_node_id, op_node) { - fusions.push(fusion); - } - } - - // Map of old_output_id => new_output_id for subgraphs that have been - // replaced by fusions.
- for input_id in fusion.input_ids.iter_mut().flatten() { - if let Some(replacement_id) = replaced_ids.get(input_id) { - *input_id = *replacement_id; - } - } - - let (old_output_id, new_output_id) = fusion.apply(self); + let fusions: Vec<_> = self + .iter_operators() + .filter_map(|(op_node_id, op_node)| create_fusion(self, op_node_id, op_node)) + .collect(); - replaced_ids.insert(old_output_id, new_output_id); + for Fusion { + name, + fused_op, + input_ids, + output_id, + } in fusions + { + self.add_operator(name.as_deref(), fused_op, &input_ids, Some(output_id)); } } @@ -219,46 +211,27 @@ struct Fusion { name: Option, fused_op: Box, input_ids: Vec>, - old_output_id: NodeId, + output_id: NodeId, } impl Fusion { /// Create a fusion with a given operator, name and input nodes. /// - /// `old_output_id` specifies the output ID of the subgraph that this fusion + /// `output_id` specifies the output ID of the subgraph that this fusion /// replaces. fn from_op( name: Option<&str>, op: Op, input_ids: Vec>, - old_output_id: NodeId, + output_id: NodeId, ) -> Fusion { Fusion { name: name.map(|s| s.to_string()), fused_op: Box::new(op), input_ids, - old_output_id, + output_id, } } - - /// Apply the fusion to the graph. - /// - /// This adds the fused operator to the graph and replaces references to - /// the original output nodes with the fused operator's outputs. - /// - /// Returns a tuple of `(old_output_id, new_output_id)`. - fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) { - let Fusion { - name, - fused_op, - input_ids, - old_output_id, - } = self; - - let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids); - graph.replace_value(old_output_id, fused_op_output_id); - (old_output_id, fused_op_output_id) - } } /// Utilities for matching patterns in a graph. 
@@ -867,6 +840,37 @@ mod tests { assert_eq!(layer_norm.epsilon, Some(1e-6)); } + #[test] + fn test_optimize_preserves_input_output_nodes() { + let mut graph = Graph::new(); + + let input_1 = graph.add_value(None, None); + let input_2 = graph.add_value(None, None); + + // Add fuse-able Transpose + MatMul + let (_, transpose_out) = + graph.add_simple_op("transpose", Transpose { perm: None }, &[input_1]); + let (_, matmul_out) = graph.add_simple_op("matmul", MatMul {}, &[transpose_out, input_2]); + graph.set_input_ids(&[input_1, input_2]); + graph.set_output_ids(&[matmul_out]); + + let graph = optimize_graph(graph).unwrap(); + + // Verify that optimizer did change the graph + let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap(); + assert_eq!(op.operator().name(), "FusedTranspose(MatMul)"); + + // The IDs of the input and output nodes should be the same after + // optimization. + // + // The optimizer could have created new output nodes instead, but it + // would need to ensure that the new outputs preserved value node + // metadata (name, shape) from the original outputs. + assert_eq!(graph.input_ids(), &[input_1, input_2]); + assert_eq!(graph.output_ids(), &[matmul_out]); + assert_eq!(graph.node_name(matmul_out), "matmul_out"); + } + #[test] fn test_optimize_error() { let mut graph = Graph::new(); From 6a3abb41e8b703e516cdc24eee10ac557350acd5 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sat, 9 Nov 2024 05:34:51 +0000 Subject: [PATCH 2/3] Correct outdated comment for `GraphOptimizer::optimize` --- src/optimize.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/optimize.rs b/src/optimize.rs index e117af5c..79465a89 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -277,12 +277,11 @@ impl GraphOptimizer { /// Apply optimizations to a graph. /// - /// The input and output nodes specified by `input_ids` and `output_ids` - /// will be preserved, but their IDs may change. 
Other nodes in the graph - /// may be modified, removed or replaced by optimization. + /// The graph's input and output nodes, identified by + /// [`input_ids`](Graph::input_ids) and [`output_ids`](Graph::output_ids) + /// will be preserved. Other nodes may be modified, removed or replaced. /// - /// This method returns the new graph along with the node IDs in the new - /// graph that correspond to `input_ids` and `output_ids`. + /// Returns the optimized graph. pub fn optimize( &self, graph: Graph, From 4937d834ffbeb74a9136df5648bc6e1f82facb78 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sat, 9 Nov 2024 06:14:06 +0000 Subject: [PATCH 3/3] Make `Graph::add_node` private again Exposing `add_node` as a public method creates a hazard as it allows callers to add nodes to the graph without post-processing steps that other methods do (eg. `add_op`). Make this method private again and add a more specific `add_constant_node` alternative to handle the single use case for it outside the graph module. --- src/graph.rs | 26 +++++++++++++++++++------- src/optimize.rs | 12 ++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index cb83b64e..99486abe 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -711,7 +711,11 @@ impl Graph { captures } - pub fn add_node(&mut self, node: Node) -> NodeId { + /// Add a node to the graph and return its ID. + /// + /// This contains the common logic for adding different types of node to + /// the graph. + fn add_node(&mut self, node: Node) -> NodeId { let node_id = NodeId::from_u32(self.nodes.len() as u32); self.nodes.push(node); @@ -790,7 +794,7 @@ impl Graph { (op_node_id, op_out_id) } - /// Add a constant node to the graph. + /// Convert `value` to a constant node and add it to the graph. /// /// `name` is an identifier for this node that is used in debug messages etc. 
/// @@ -800,12 +804,20 @@ impl Graph { V: Into>, ConstantNode: Into, { - let node = ConstantNode { - name: name.map(|s| s.to_owned()), - data: value.into(), - }; + self.add_constant_node( + ConstantNode { + name: name.map(|s| s.to_owned()), + data: value.into(), + } + .into(), + ) + } - self.add_node(Node::Constant(node.into())) + /// Add a constant node to the graph. + /// + /// Returns the ID of the added node. + pub fn add_constant_node(&mut self, node: Constant) -> NodeId { + self.add_node(Node::Constant(node)) } /// Add a value node to the graph. diff --git a/src/optimize.rs b/src/optimize.rs index 79465a89..fd035abb 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -78,6 +78,11 @@ impl GraphMutator { self.graph.add_constant(name, value) } + /// Add a new constant value to the graph. + fn add_constant_node(&mut self, const_node: Constant) -> NodeId { + self.graph.add_constant_node(const_node) + } + /// Add a new operator to the graph with a single output node. /// /// `op_output_id` specifies the ID of the output node. If not specified, @@ -105,11 +110,6 @@ impl GraphMutator { op_output_id } - /// Add a new node to the graph. - fn add_node(&mut self, node: Node) -> NodeId { - self.graph.add_node(node) - } - /// Return a reference to the graph. /// /// Note there is no mutable variant of this method. All graph updates must @@ -330,7 +330,7 @@ impl GraphOptimizer { let mut new_captures: FxHashSet<_> = graph.graph().captures().iter().copied().collect(); for (capture_id, local_const) in captured_constants { - let const_id = graph.add_node(Node::Constant(local_const)); + let const_id = graph.add_constant_node(local_const); new_captures.remove(&capture_id); graph.replace_value(capture_id, const_id); }