diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index 255b3e5a7..ccc10ef25 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -168,7 +168,7 @@ def replace_all_uses_with( values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], ) -> None: - """Replace all consumers of the given values with the replacements. + """Replace all uses of the given values with the replacements. This is useful when nodes in the graph are replaced with new nodes, where the old users need to be updated to use the outputs of the new nodes. @@ -194,7 +194,7 @@ def replace_all_uses_with( 1 >>> node_c.inputs[0].producer().op_type 'D' - >>> len(node_a.outputs[0].consumers()) + >>> len(node_a.outputs[0].uses()) 0 When values and replacements are sequences, they are zipped into pairs. All @@ -216,5 +216,5 @@ def replace_all_uses_with( if len(values) != len(replacements): raise ValueError("The number of values and replacements must match.") for value, replacement in zip(values, replacements): - for user_node, index in tuple(value.consumers()): + for user_node, index in tuple(value.uses()): user_node.replace_input_with(index, replacement) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index e149b423e..06549df47 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -635,11 +635,11 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable): user is responsible to call ``graph.append(node)`` (or other mutation methods in :class:`Graph`) to add the node to the graph. - After the node is initialized, it will add itself as a consumer of the input values. + After the node is initialized, it will add itself as a user of the input values. The output values of the node are created during node initialization and are immutable. - To change the output values, create a new node and replace the each of the inputs of ``output.consumers`` with - the new output values by calling :meth:`replace_input_with` on the consumer nodes + To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with + the new output values by calling :meth:`replace_input_with` on the using nodes of this node's outputs. """ @@ -673,7 +673,7 @@ def __init__( doc_string: str | None = None, metadata_props: dict[str, str] | None = None, ): - """Initialize a node and add it as a consumer of the input values. + """Initialize a node and add it as a user of the input values. Args: domain: The domain of the operator. For onnx operators, this is an empty string. @@ -718,10 +718,10 @@ def __init__( self._graph: Graph | None = graph self.doc_string = doc_string - # Add the node as a consumer of the inputs + # Add the node as a use of the inputs for i, input_value in enumerate(self._inputs): if input_value is not None: - input_value._add_consumer(self, i) # pylint: disable=protected-access + input_value._add_usage(self, i) # pylint: disable=protected-access # Add the node to the graph if graph is specified if self._graph is not None: @@ -821,9 +821,9 @@ def replace_input_with(self, index: int, value: Value | None) -> None: value if i == index else old_input for i, old_input in enumerate(self.inputs) ) if old_input is not None: - old_input._remove_consumer(self, index) # pylint: disable=protected-access + old_input._remove_usage(self, index) # pylint: disable=protected-access if value is not None: - value._add_consumer(self, index) # pylint: disable=protected-access + value._add_usage(self, index) # pylint: disable=protected-access def prepend(self, /, nodes: Node | Iterable[Node]) -> None: """Insert a node before this node in the list of nodes in the graph. @@ -1016,7 +1016,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable): The index of the output of the node that produces the value can be accessed with :meth:`index`. - To find all the nodes that use this value as an input, call :meth:`consumers`. + To find all the nodes that use this value as an input, call :meth:`uses`. To check if the value is an output of a graph, call :meth:`is_graph_output`. @@ -1036,7 +1036,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable): "_shape", "_type", "_const_value", - "_consumers", + "_uses", ) def __init__( @@ -1063,10 +1063,10 @@ def __init__( # TODO(justinchuby): Handle initialization when a const value is provided # We can get shape and type information from the const value self._const_value = const_value - # Use a collection of (Node, int) to store consumers. This is needed - # because a single consumer can use the same value multiple times. + # Use a collection of (Node, int) to store uses. This is needed + # because a single use can use the same value multiple times. # Use a dictionary to preserve insertion order so that the visiting order is deterministic - self._consumers: dict[tuple[Node, int], None] = {} + self._uses: dict[tuple[Node, int], None] = {} def __repr__(self) -> str: value_name = self.name if self.name else "anonymous:" + str(id(self)) @@ -1095,27 +1095,27 @@ def index(self) -> int | None: """The index of the output of the defining node.""" return self._index - def consumers(self) -> Collection[tuple[Node, int]]: - """Return a set of consumers of the value. + def uses(self) -> Collection[tuple[Node, int]]: + """Return a set of uses of the value. The set contains tuples of ``(Node, index)`` where the index is the index of the input - of the node. For example, if ``node.inputs[1] == value``, then the consumer is ``(node, 1)``. + of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``. """ - return self._consumers.keys() + return self._uses.keys() - def _add_consumer(self, consumer: Node, index: int) -> None: - """Add a consumer node. + def _add_usage(self, use: Node, index: int) -> None: + """Add a usage of this value. This is an internal method. It should only be called by the Node class. """ - self._consumers[(consumer, index)] = None + self._uses[(use, index)] = None - def _remove_consumer(self, consumer: Node, index: int) -> None: - """Remove a node from the consumers of this value. + def _remove_usage(self, use: Node, index: int) -> None: + """Remove a node from the uses of this value. This is an internal method. It should only be called by the Node class. """ - self._consumers.pop((consumer, index)) + self._uses.pop((use, index)) @property def name(self) -> str | None: @@ -1246,7 +1246,7 @@ def _check_node_safe_to_remove( to be removed before removing it. 2. It checks the node does not contribute to any graph outputs. - This check is typically O(1) assuming the number of consumers of the node is small + This check is typically O(1) assuming the number of uses of the node is small Args: node: The node to check. @@ -1264,12 +1264,12 @@ def _check_node_safe_to_remove( raise ValueError( f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True." ) - for consumer, _ in output.consumers(): - if consumer in to_remove: + for use, _ in output.uses(): + if use in to_remove: continue raise ValueError( - f"Node '{consumer!r}' is still being used by other nodes that are not to be " - f"removed. All of its uses: {list(output.consumers())!r}" + f"Node '{use!r}' is still being used by other nodes that are not to be " + f"removed. All of its uses: {list(output.uses())!r}" ) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 51cd4e2a6..6746d81b1 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -434,8 +434,8 @@ def test_remove_removes_node_from_graph(self): def test_remove_does_not_change_input_users(self): self.graph.remove(self.node) - self.assertEqual(tuple(self.v0.consumers()), ((self.node, 0),)) - self.assertEqual(tuple(self.v1.consumers()), ((self.node, 1),)) + self.assertEqual(tuple(self.v0.uses()), ((self.node, 0),)) + self.assertEqual(tuple(self.v1.uses()), ((self.node, 1),)) def test_remove_does_not_change_graph_in_out(self): self.graph.remove(self.node) @@ -481,8 +481,8 @@ def test_remove_safe_removes_uses_of_removed_nodes(self): identity_node.replace_input_with(0, sub_node.outputs[0]) graph.insert_before(identity_node, sub_node) graph.remove(add_node, safe=True) - self.assertEqual(tuple(v0.consumers()), ((sub_node, 0),)) - self.assertEqual(tuple(v1.consumers()), ((sub_node, 1),)) + self.assertEqual(tuple(v0.uses()), ((sub_node, 0),)) + self.assertEqual(tuple(v1.uses()), ((sub_node, 1),)) self.assertEqual(tuple(graph), (sub_node, identity_node)) self.assertEqual(add_node.inputs, (None, None)) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index d920af9e4..44b521495 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -137,7 +137,7 @@ class ValueProtocol(Protocol): The index of the output of the node that produces the value can be accessed with :meth:`index`. - To find all the nodes that use this value as an input, call :meth:`consumers`. + To find all the nodes that use this value as an input, call :meth:`uses`. To check if the value is an output of a graph, call :meth:`is_graph_output`. @@ -163,7 +163,7 @@ def index(self) -> int | None: """The index of the output of the node that produces this value.""" ... - def consumers(self) -> Collection[tuple[NodeProtocol, int]]: + def uses(self) -> Collection[tuple[NodeProtocol, int]]: """The set of (node, input_index) with node being those that use this value as an input.""" ... diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 3adbc58ca..6cc70da64 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -505,8 +505,8 @@ def _match_forward( return self.none(root_node, inspect.currentframe().f_lineno) for o, op in zip(graph_node.outputs, pattern_node.outputs): - graph_node_users = [user for user, _ in o.consumers()] - pattern_node_users = [user for user, _ in op.consumers()] + graph_node_users = [user for user, _ in o.uses()] + pattern_node_users = [user for user, _ in op.uses()] if not pattern_node_users: # The pattern has no node forward, the matching stops. continue diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6fcf76a7b..050851932 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -780,7 +780,7 @@ def _valid_to_replace(matched_nodes: Sequence[Any]) -> bool: if v.is_graph_output(): # value is an output-value of the graph/function. return False - for consumer, _ in v.consumers(): + for consumer, _ in v.uses(): if consumer not in matched_nodes: return False return True