Skip to content

Commit

Permalink
[IR] Rename Value.consumers() to Value.uses() (#1422)
Browse files Browse the repository at this point in the history
A second renaming PR, hopefully more stable afterward... 

1. The consumer() naming is not as accurate as it can be, because while
a `Node` is a consumer/user, a `(Node, index)` is a "use". At the same
time ONNX C++ IR
(https://github.com/onnx/onnx/blob/c459890aa266d74ca612f334d0f3e869dcdbb597/onnx/common/ir.h#L288)
as well as torch fx
(https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/fx/node.py#L226)
both uses the `user` / `use` concept. FX calls it `user` because it is a
Node-Node relationship. We have a Node-(Node (user), index)
relationship.
  • Loading branch information
justinchuby authored Apr 23, 2024
1 parent 7a30f37 commit 8060e2d
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 40 deletions.
6 changes: 3 additions & 3 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def replace_all_uses_with(
values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
) -> None:
"""Replace all consumers of the given values with the replacements.
"""Replace all uses of the given values with the replacements.
This is useful when nodes in the graph are replaced with new nodes, where
the old users need to be updated to use the outputs of the new nodes.
Expand All @@ -194,7 +194,7 @@ def replace_all_uses_with(
1
>>> node_c.inputs[0].producer().op_type
'D'
>>> len(node_a.outputs[0].consumers())
>>> len(node_a.outputs[0].uses())
0
When values and replacements are sequences, they are zipped into pairs. All
Expand All @@ -216,5 +216,5 @@ def replace_all_uses_with(
if len(values) != len(replacements):
raise ValueError("The number of values and replacements must match.")
for value, replacement in zip(values, replacements):
for user_node, index in tuple(value.consumers()):
for user_node, index in tuple(value.uses()):
user_node.replace_input_with(index, replacement)
56 changes: 28 additions & 28 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,11 +635,11 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
user is responsible to call ``graph.append(node)`` (or other mutation methods
in :class:`Graph`) to add the node to the graph.
After the node is initialized, it will add itself as a consumer of the input values.
After the node is initialized, it will add itself as a user of the input values.
The output values of the node are created during node initialization and are immutable.
To change the output values, create a new node and replace the each of the inputs of ``output.consumers`` with
the new output values by calling :meth:`replace_input_with` on the consumer nodes
To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
the new output values by calling :meth:`replace_input_with` on the using nodes
of this node's outputs.
"""

Expand Down Expand Up @@ -673,7 +673,7 @@ def __init__(
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
):
"""Initialize a node and add it as a consumer of the input values.
"""Initialize a node and add it as a user of the input values.
Args:
domain: The domain of the operator. For onnx operators, this is an empty string.
Expand Down Expand Up @@ -718,10 +718,10 @@ def __init__(
self._graph: Graph | None = graph
self.doc_string = doc_string

# Add the node as a consumer of the inputs
# Add the node as a use of the inputs
for i, input_value in enumerate(self._inputs):
if input_value is not None:
input_value._add_consumer(self, i) # pylint: disable=protected-access
input_value._add_usage(self, i) # pylint: disable=protected-access

# Add the node to the graph if graph is specified
if self._graph is not None:
Expand Down Expand Up @@ -821,9 +821,9 @@ def replace_input_with(self, index: int, value: Value | None) -> None:
value if i == index else old_input for i, old_input in enumerate(self.inputs)
)
if old_input is not None:
old_input._remove_consumer(self, index) # pylint: disable=protected-access
old_input._remove_usage(self, index) # pylint: disable=protected-access
if value is not None:
value._add_consumer(self, index) # pylint: disable=protected-access
value._add_usage(self, index) # pylint: disable=protected-access

def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
"""Insert a node before this node in the list of nodes in the graph.
Expand Down Expand Up @@ -1016,7 +1016,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
The index of the output of the node that produces the value can be accessed with
:meth:`index`.
To find all the nodes that use this value as an input, call :meth:`consumers`.
To find all the nodes that use this value as an input, call :meth:`uses`.
To check if the value is an output of a graph, call :meth:`is_graph_output`.
Expand All @@ -1036,7 +1036,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
"_shape",
"_type",
"_const_value",
"_consumers",
"_uses",
)

def __init__(
Expand All @@ -1063,10 +1063,10 @@ def __init__(
# TODO(justinchuby): Handle initialization when a const value is provided
# We can get shape and type information from the const value
self._const_value = const_value
# Use a collection of (Node, int) to store consumers. This is needed
# because a single consumer can use the same value multiple times.
# Use a collection of (Node, int) to store uses. This is needed
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._consumers: dict[tuple[Node, int], None] = {}
self._uses: dict[tuple[Node, int], None] = {}

def __repr__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
Expand Down Expand Up @@ -1095,27 +1095,27 @@ def index(self) -> int | None:
"""The index of the output of the defining node."""
return self._index

def consumers(self) -> Collection[tuple[Node, int]]:
"""Return a set of consumers of the value.
def uses(self) -> Collection[tuple[Node, int]]:
"""Return a set of uses of the value.
The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the consumer is ``(node, 1)``.
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
"""
return self._consumers.keys()
return self._uses.keys()

def _add_consumer(self, consumer: Node, index: int) -> None:
"""Add a consumer node.
def _add_usage(self, use: Node, index: int) -> None:
"""Add a usage of this value.
This is an internal method. It should only be called by the Node class.
"""
self._consumers[(consumer, index)] = None
self._uses[(use, index)] = None

def _remove_consumer(self, consumer: Node, index: int) -> None:
"""Remove a node from the consumers of this value.
def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.
This is an internal method. It should only be called by the Node class.
"""
self._consumers.pop((consumer, index))
self._uses.pop((use, index))

@property
def name(self) -> str | None:
Expand Down Expand Up @@ -1246,7 +1246,7 @@ def _check_node_safe_to_remove(
to be removed before removing it.
2. It checks the node does not contribute to any graph outputs.
This check is typically O(1) assuming the number of consumers of the node is small
This check is typically O(1) assuming the number of uses of the node is small
Args:
node: The node to check.
Expand All @@ -1264,12 +1264,12 @@ def _check_node_safe_to_remove(
raise ValueError(
f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True."
)
for consumer, _ in output.consumers():
if consumer in to_remove:
for use, _ in output.uses():
if use in to_remove:
continue
raise ValueError(
f"Node '{consumer!r}' is still being used by other nodes that are not to be "
f"removed. All of its uses: {list(output.consumers())!r}"
f"Node '{use!r}' is still being used by other nodes that are not to be "
f"removed. All of its uses: {list(output.uses())!r}"
)


Expand Down
8 changes: 4 additions & 4 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ def test_remove_removes_node_from_graph(self):

def test_remove_does_not_change_input_users(self):
self.graph.remove(self.node)
self.assertEqual(tuple(self.v0.consumers()), ((self.node, 0),))
self.assertEqual(tuple(self.v1.consumers()), ((self.node, 1),))
self.assertEqual(tuple(self.v0.uses()), ((self.node, 0),))
self.assertEqual(tuple(self.v1.uses()), ((self.node, 1),))

def test_remove_does_not_change_graph_in_out(self):
self.graph.remove(self.node)
Expand Down Expand Up @@ -481,8 +481,8 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
identity_node.replace_input_with(0, sub_node.outputs[0])
graph.insert_before(identity_node, sub_node)
graph.remove(add_node, safe=True)
self.assertEqual(tuple(v0.consumers()), ((sub_node, 0),))
self.assertEqual(tuple(v1.consumers()), ((sub_node, 1),))
self.assertEqual(tuple(v0.uses()), ((sub_node, 0),))
self.assertEqual(tuple(v1.uses()), ((sub_node, 1),))
self.assertEqual(tuple(graph), (sub_node, identity_node))
self.assertEqual(add_node.inputs, (None, None))

Expand Down
4 changes: 2 additions & 2 deletions onnxscript/ir/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class ValueProtocol(Protocol):
The index of the output of the node that produces the value can be accessed with
:meth:`index`.
To find all the nodes that use this value as an input, call :meth:`consumers`.
To find all the nodes that use this value as an input, call :meth:`uses`.
To check if the value is an output of a graph, call :meth:`is_graph_output`.
Expand All @@ -163,7 +163,7 @@ def index(self) -> int | None:
"""The index of the output of the node that produces this value."""
...

def consumers(self) -> Collection[tuple[NodeProtocol, int]]:
def uses(self) -> Collection[tuple[NodeProtocol, int]]:
"""The set of (node, input_index) with node being those that use this value as an input."""
...

Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,8 @@ def _match_forward(
return self.none(root_node, inspect.currentframe().f_lineno)

for o, op in zip(graph_node.outputs, pattern_node.outputs):
graph_node_users = [user for user, _ in o.consumers()]
pattern_node_users = [user for user, _ in op.consumers()]
graph_node_users = [user for user, _ in o.uses()]
pattern_node_users = [user for user, _ in op.uses()]
if not pattern_node_users:
# The pattern has no node forward, the matching stops.
continue
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def _valid_to_replace(matched_nodes: Sequence[Any]) -> bool:
if v.is_graph_output():
# value is an output-value of the graph/function.
return False
for consumer, _ in v.consumers():
for consumer, _ in v.uses():
if consumer not in matched_nodes:
return False
return True
Expand Down

0 comments on commit 8060e2d

Please sign in to comment.