Skip to content

Commit

Permalink
Make Graph::add_node private again
Browse files Browse the repository at this point in the history
Exposing `add_node` as a public method creates a hazard as it allows callers to
add nodes to the graph without post-processing steps that other methods do (eg.
`add_op`). Make this method private again and add a more specific
`add_constant_node` alternative to handle the single use case for it outside the
graph module.
  • Loading branch information
robertknight committed Nov 9, 2024
1 parent 1df8de4 commit 04559ca
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
26 changes: 19 additions & 7 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,11 @@ impl Graph {
captures
}

pub fn add_node(&mut self, node: Node) -> NodeId {
/// Add a node to the graph and return its ID.
///
/// This contains the common logic for adding different types of node to
/// the graph.
fn add_node(&mut self, node: Node) -> NodeId {
let node_id = NodeId::from_u32(self.nodes.len() as u32);
self.nodes.push(node);

Expand Down Expand Up @@ -790,7 +794,7 @@ impl Graph {
(op_node_id, op_out_id)
}

/// Add a constant node to the graph.
/// Convert `value` to a constant node and add it to the graph.
///
/// `name` is an identifier for this node that is used in debug messages etc.
///
Expand All @@ -800,12 +804,20 @@ impl Graph {
V: Into<ConstantNodeData<T>>,
ConstantNode<T>: Into<Constant>,
{
let node = ConstantNode {
name: name.map(|s| s.to_owned()),
data: value.into(),
};
self.add_constant_node(
ConstantNode {
name: name.map(|s| s.to_owned()),
data: value.into(),
}
.into(),
)
}

self.add_node(Node::Constant(node.into()))
/// Add a constant node to the graph.
///
/// Returns the ID of the added node.
pub fn add_constant_node(&mut self, node: Constant) -> NodeId {
self.add_node(Node::Constant(node))
}

/// Add a value node to the graph.
Expand Down
12 changes: 6 additions & 6 deletions src/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ impl GraphMutator {
self.graph.add_constant(name, value)
}

/// Add a new constant value to the graph.
fn add_constant_node(&mut self, const_node: Constant) -> NodeId {
self.graph.add_constant_node(const_node)
}

/// Add a new operator to the graph with a single output node.
///
/// `op_output_id` specifies the ID of the output node. If not specified,
Expand Down Expand Up @@ -105,11 +110,6 @@ impl GraphMutator {
op_output_id
}

/// Add a new node to the graph.
fn add_node(&mut self, node: Node) -> NodeId {
self.graph.add_node(node)
}

/// Return a reference to the graph.
///
/// Note there is no mutable variant of this method. All graph updates must
Expand Down Expand Up @@ -330,7 +330,7 @@ impl GraphOptimizer {
let mut new_captures: FxHashSet<_> = graph.graph().captures().iter().copied().collect();

for (capture_id, local_const) in captured_constants {
let const_id = graph.add_node(Node::Constant(local_const));
let const_id = graph.add_constant_node(local_const);
new_captures.remove(&capture_id);
graph.replace_value(capture_id, const_id);
}
Expand Down

0 comments on commit 04559ca

Please sign in to comment.