Skip to content

Commit

Permalink
Merge pull request #401 from robertknight/optimize-reuse-output
Browse files Browse the repository at this point in the history
Reuse existing output node when replacing operator with a fusion
  • Loading branch information
robertknight authored Nov 9, 2024
2 parents 05aceee + 4937d83 commit cf1ee74
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 67 deletions.
42 changes: 34 additions & 8 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,11 @@ impl Graph {
captures
}

pub fn add_node(&mut self, node: Node) -> NodeId {
/// Add a node to the graph and return its ID.
///
/// This contains the common logic for adding different types of node to
/// the graph.
fn add_node(&mut self, node: Node) -> NodeId {
let node_id = NodeId::from_u32(self.nodes.len() as u32);
self.nodes.push(node);

Expand All @@ -722,6 +726,13 @@ impl Graph {
node_id
}

/// Invalidate cached execution plans.
fn clear_cached_plan(&mut self) {
if let Ok(plan) = self.cached_plan.get_mut() {
plan.take();
}
}

/// Add an operator node to the graph.
///
/// `name` is an identifier for this node that is used in debug messages etc.
Expand All @@ -732,7 +743,10 @@ impl Graph {
/// operators.
///
/// `outputs` specifies which value nodes the operator's outputs should be
/// written to.
/// written to. If there is already an existing operator which uses the
/// same output, the new operator will become the source for this output
/// value. This enables replacing an operator while preserving metadata
/// of the output value (name, shape etc.).
///
/// Returns the ID of the operator node.
pub fn add_op(
Expand All @@ -753,6 +767,10 @@ impl Graph {
self.source_ids.insert(*output_id, op_id);
}

// Clear cached plan in case we just replaced the source operator for
// one of the output IDs.
self.clear_cached_plan();

op_id
}

Expand All @@ -776,7 +794,7 @@ impl Graph {
(op_node_id, op_out_id)
}

/// Add a constant node to the graph.
/// Convert `value` to a constant node and add it to the graph.
///
/// `name` is an identifier for this node that is used in debug messages etc.
///
Expand All @@ -786,12 +804,20 @@ impl Graph {
V: Into<ConstantNodeData<T>>,
ConstantNode<T>: Into<Constant>,
{
let node = ConstantNode {
name: name.map(|s| s.to_owned()),
data: value.into(),
};
self.add_constant_node(
ConstantNode {
name: name.map(|s| s.to_owned()),
data: value.into(),
}
.into(),
)
}

self.add_node(Node::Constant(node.into()))
/// Add a constant node to the graph.
///
/// Returns the ID of the added node.
pub fn add_constant_node(&mut self, node: Constant) -> NodeId {
self.add_node(Node::Constant(node))
}

/// Add a value node to the graph.
Expand Down
121 changes: 62 additions & 59 deletions src/optimize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Display, Formatter};

Expand Down Expand Up @@ -79,16 +78,25 @@ impl GraphMutator {
self.graph.add_constant(name, value)
}

/// Add a new constant value to the graph.
fn add_constant_node(&mut self, const_node: Constant) -> NodeId {
self.graph.add_constant_node(const_node)
}

/// Add a new operator to the graph with a single output node.
///
/// `op_output_id` specifies the ID of the output node. If not specified,
/// a new value node is created.
///
/// Returns the ID of the output node.
fn add_operator(
&mut self,
name: Option<&str>,
op: Box<dyn Operator + Send + Sync>,
inputs: &[Option<NodeId>],
op_output_id: Option<NodeId>,
) -> NodeId {
let op_output_id = self.graph.add_value(None, None);
let op_output_id = op_output_id.unwrap_or(self.graph.add_value(None, None));
let op_id = self.graph.add_op(name, op, inputs, &[Some(op_output_id)]);

for input_id in inputs.iter().filter_map(|id| *id) {
Expand All @@ -102,11 +110,6 @@ impl GraphMutator {
op_output_id
}

/// Add a new node to the graph.
fn add_node(&mut self, node: Node) -> NodeId {
self.graph.add_node(node)
}

/// Return a reference to the graph.
///
/// Note there is no mutable variant of this method. All graph updates must
Expand Down Expand Up @@ -135,30 +138,19 @@ impl GraphMutator {
&mut self,
create_fusion: F,
) {
let mut fusions = Vec::new();

for (op_node_id, op_node) in self.iter_operators() {
if let Some(fusion) = create_fusion(self, op_node_id, op_node) {
fusions.push(fusion);
}
}

// Map of old_output_id => new_output_id for subgraphs that have been
// replaced by fusions.
let mut replaced_ids = HashMap::<NodeId, NodeId>::new();

for mut fusion in fusions {
// Replace input IDs which match output IDs of previously applied
// fusions.
for input_id in fusion.input_ids.iter_mut().flatten() {
if let Some(replacement_id) = replaced_ids.get(input_id) {
*input_id = *replacement_id;
}
}

let (old_output_id, new_output_id) = fusion.apply(self);
let fusions: Vec<_> = self
.iter_operators()
.filter_map(|(op_node_id, op_node)| create_fusion(self, op_node_id, op_node))
.collect();

replaced_ids.insert(old_output_id, new_output_id);
for Fusion {
name,
fused_op,
input_ids,
output_id,
} in fusions
{
self.add_operator(name.as_deref(), fused_op, &input_ids, Some(output_id));
}
}

Expand Down Expand Up @@ -219,46 +211,27 @@ struct Fusion {
name: Option<String>,
fused_op: Box<dyn Operator + Send + Sync>,
input_ids: Vec<Option<NodeId>>,
old_output_id: NodeId,
output_id: NodeId,
}

impl Fusion {
/// Create a fusion with a given operator, name and input nodes.
///
/// `old_output_id` specifies the output ID of the subgraph that this fusion
/// `output_id` specifies the output ID of the subgraph that this fusion
/// replaces.
fn from_op<Op: Operator + Send + Sync>(
name: Option<&str>,
op: Op,
input_ids: Vec<Option<NodeId>>,
old_output_id: NodeId,
output_id: NodeId,
) -> Fusion {
Fusion {
name: name.map(|s| s.to_string()),
fused_op: Box::new(op),
input_ids,
old_output_id,
output_id,
}
}

/// Apply the fusion to the graph.
///
/// This adds the fused operator to the graph and replaces references to
/// the original output nodes with the fused operator's outputs.
///
/// Returns a tuple of `(old_output_id, new_output_id)`.
fn apply(self, graph: &mut GraphMutator) -> (NodeId, NodeId) {
let Fusion {
name,
fused_op,
input_ids,
old_output_id,
} = self;

let fused_op_output_id = graph.add_operator(name.as_deref(), fused_op, &input_ids);
graph.replace_value(old_output_id, fused_op_output_id);
(old_output_id, fused_op_output_id)
}
}

/// Utilities for matching patterns in a graph.
Expand Down Expand Up @@ -304,12 +277,11 @@ impl GraphOptimizer {

/// Apply optimizations to a graph.
///
/// The input and output nodes specified by `input_ids` and `output_ids`
/// will be preserved, but their IDs may change. Other nodes in the graph
/// may be modified, removed or replaced by optimization.
/// The graph's input and output nodes, identified by
/// [`input_ids`](Graph::input_ids) and [`output_ids`](Graph::output_ids)
/// will be preserved. Other nodes may be modified, removed or replaced.
///
/// This method returns the new graph along with the node IDs in the new
/// graph that correspond to `input_ids` and `output_ids`.
/// Returns the optimized graph.
pub fn optimize(
&self,
graph: Graph,
Expand Down Expand Up @@ -358,7 +330,7 @@ impl GraphOptimizer {
let mut new_captures: FxHashSet<_> = graph.graph().captures().iter().copied().collect();

for (capture_id, local_const) in captured_constants {
let const_id = graph.add_node(Node::Constant(local_const));
let const_id = graph.add_constant_node(local_const);
new_captures.remove(&capture_id);
graph.replace_value(capture_id, const_id);
}
Expand Down Expand Up @@ -867,6 +839,37 @@ mod tests {
assert_eq!(layer_norm.epsilon, Some(1e-6));
}

#[test]
fn test_optimize_preserves_input_output_nodes() {
let mut graph = Graph::new();

let input_1 = graph.add_value(None, None);
let input_2 = graph.add_value(None, None);

// Add fuse-able Transpose + MatMul
let (_, transpose_out) =
graph.add_simple_op("transpose", Transpose { perm: None }, &[input_1]);
let (_, matmul_out) = graph.add_simple_op("matmul", MatMul {}, &[transpose_out, input_2]);
graph.set_input_ids(&[input_1, input_2]);
graph.set_output_ids(&[matmul_out]);

let graph = optimize_graph(graph).unwrap();

// Verify that optimizer did change the graph
let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
assert_eq!(op.operator().name(), "FusedTranspose(MatMul)");

// The IDs of the input and output nodes should be the same after
// optimization.
//
// The optimizer could have created new output nodes instead, but it
// would need to ensure that the new outputs preserved value node
// metadata (name, shape) from the original outputs.
assert_eq!(graph.input_ids(), &[input_1, input_2]);
assert_eq!(graph.output_ids(), &[matmul_out]);
assert_eq!(graph.node_name(matmul_out), "matmul_out");
}

#[test]
fn test_optimize_error() {
let mut graph = Graph::new();
Expand Down

0 comments on commit cf1ee74

Please sign in to comment.