From 2b6dc27b34f2e4e9fc4c3ad73635c5b157a4c714 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 May 2024 17:59:22 -0700 Subject: [PATCH] [IR] Improve name authority to generate unique names (#1537) - Store all names from the graph for generating unique names for new values. - Also allow values to be initialized with no arguments. Fix #1535 --- onnxscript/ir/_core.py | 41 +++++++++---- onnxscript/ir/_core_test.py | 12 ++-- onnxscript/ir/_name_authority.py | 59 +++++++++++++++---- onnxscript/ir/_name_authority_test.py | 26 ++++++++ onnxscript/ir/serde.py | 8 +-- onnxscript/ir/serde_test.py | 6 +- .../bfloat16_utils/bfloat16_converter.py | 4 -- 7 files changed, 115 insertions(+), 41 deletions(-) create mode 100644 onnxscript/ir/_name_authority_test.py diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 6f81598e1..2f42b8b9b 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1357,9 +1357,9 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable): def __init__( self, - producer: Node | None, + producer: Node | None = None, *, - index: int | None, + index: int | None = None, name: str | None = None, shape: Shape | None = None, type: _protocols.TypeProtocol | None = None, @@ -1368,7 +1368,18 @@ def __init__( | Sequence[_protocols.TensorProtocol] | None = None, ) -> None: - # producer is None when the value is an input or an initializer + """Initialize a value. + + Args: + producer: The node that produces the value. + It can be ``None`` when the value is initialized first than its producer. + index: The index of the output of the defining node. + name: The name of the value. + shape: The shape of the value. + type: The type of the value. + doc_string: The documentation string. + const_value: The constant tensor is the value constant. + """ self._producer: Node | None = producer self._index: int | None = index self._metadata: _metadata.MetadataStore | None = None @@ -1406,7 +1417,11 @@ def __str__(self) -> str: return f"%{_quoted(value_name)}<{type_text},{shape_text}>" def producer(self) -> Node | None: - """The node that produces this value.""" + """The node that produces this value. + + When producer is ``None``, the value does not belong to a node, and is + typically a graph input or an initializer. + """ return self._producer def index(self) -> int | None: @@ -1550,9 +1565,7 @@ def __init__( type: _protocols.TypeProtocol | None = None, doc_string: str | None = None, ) -> None: - super().__init__( - None, index=None, name=name, shape=shape, type=type, doc_string=doc_string - ) + super().__init__(name=name, shape=shape, type=type, doc_string=doc_string) def _check_node_safe_to_remove( @@ -1712,11 +1725,9 @@ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node: f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()." ) # Give the node and its output values names if they don't not have one - if node.name is None: - self._name_authority.name_node(node) + self._name_authority.register_or_name_node(node) for value in node._outputs: # pylint: disable=protected-access - if value.name is None: - self._name_authority.name_value(value) + self._name_authority.register_or_name_value(value) node.graph = self return node @@ -1766,6 +1777,8 @@ def num_nodes(self) -> int: def append(self, node: Node, /) -> None: """Append a node to the graph in O(1) time. + Unique names will be assigned to the node and its values if any name is ``None``. + Args: node: The node to append. @@ -1778,6 +1791,8 @@ def append(self, node: Node, /) -> None: def extend(self, nodes: Iterable[Node], /) -> None: """Extend the graph with the given nodes in O(#new_nodes) time. + Unique names will be assigned to the node and its values if any name is ``None``. + Args: nodes: The nodes to extend the graph with. @@ -1830,6 +1845,8 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: """Insert new nodes after the given node in O(#new_nodes) time. + Unique names will be assigned to the node and its values if any name is ``None``. + Args: node: The node to insert after. new_nodes: The new nodes to insert. @@ -1845,6 +1862,8 @@ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: """Insert new nodes before the given node in O(#new_nodes) time. + Unique names will be assigned to the node and its values if any name is ``None``. + Args: node: The node to insert before. new_nodes: The new nodes to insert. diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 07c3301c0..e31d85187 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -554,10 +554,10 @@ def test_set_denotation_is_still_possible_when_shape_is_frozen(self): class ValueTest(unittest.TestCase): def test_initialize(self): - _ = _core.Value(None, index=0) + _ = _core.Value() def test_meta(self): - value = _core.Value(None, index=0) + value = _core.Value() value.meta["test"] = 1 self.assertEqual(value.meta["test"], 1) value.metadata_props["test"] = "any string" @@ -568,8 +568,8 @@ def test_meta(self): class NodeTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = _core.Value(None, index=None) - self.v1 = _core.Value(None, index=None) + self.v0 = _core.Value() + self.v1 = _core.Value() self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) def test_init_with_values(self): @@ -581,15 +581,11 @@ def test_init_with_values(self): def test_init_with_preinitialized_outputs(self): out_1 = _core.Value( - None, - index=None, name="out_1", shape=_core.Shape([1]), type=_core.TensorType(ir.DataType.BFLOAT16), ) out_2 = _core.Value( - None, - index=None, name="out_2", shape=_core.Shape([2]), type=_core.TensorType(ir.DataType.INT4), diff --git a/onnxscript/ir/_name_authority.py b/onnxscript/ir/_name_authority.py index 856c86247..895433564 100644 --- a/onnxscript/ir/_name_authority.py +++ b/onnxscript/ir/_name_authority.py @@ -12,20 +12,59 @@ class NameAuthority: ``node_{op_type}_{node_counter}`` for nodes. The counter is incremented each time a new value or node is named. - The class does not keep track of the names it has given, so it is possible to - generate names that conflicts with existing names. It is the responsibility of the - user to ensure that the names are unique (typically by running a name-fixing pass - on the graph). + This class keeps tracks of the names it has generated and existing names + in the graph to prevent producing duplicated names. + + .. note:: + Once a name is tracked, it will not be made available even if the node/value + is removed from the graph. It is possible to improve this behavior by keeping + track of the names that are no longer used, but it is not implemented yet. + + However, if a value/node is already named when added to the graph, + the name authority will not change its name. + It is the responsibility of the user to ensure that the names are unique + (typically by running a name-fixing pass on the graph). + + TODO(justichuby): Describe the pass when we have a reference implementation. """ def __init__(self): self._value_counter = 0 self._node_counter = 0 + self._value_names: set[str] = set() + self._node_names: set[str] = set() + + def _unique_value_name(self) -> str: + """Generate a unique name for a value.""" + while True: + name = f"val_{self._value_counter}" + self._value_counter += 1 + if name not in self._value_names: + return name + + def _unique_node_name(self, op_type: str) -> str: + """Generate a unique name for a node.""" + while True: + name = f"node_{op_type}_{self._node_counter}" + self._node_counter += 1 + if name not in self._node_names: + return name - def name_value(self, value: _core.Value) -> None: - value.name = f"val_{self._value_counter}" - self._value_counter += 1 + def register_or_name_value(self, value: _core.Value) -> None: + # TODO(justinchuby): Record names of the initializers and graph inputs + if value.name is None: + value.name = self._unique_value_name() + # If the name is already specified, we do not change it because keeping + # track of the used names can be costly when nodes can be removed from the graph: + # How do we know if a name is no longer used? We cannot reserve unused names + # because users may want to use them. + self._value_names.add(value.name) - def name_node(self, node: _core.Node) -> None: - node.name = f"node_{node.op_type}_{self._node_counter}" - self._node_counter += 1 + def register_or_name_node(self, node: _core.Node) -> None: + if node.name is None: + node.name = self._unique_node_name(node.op_type) + # If the name is already specified, we do not change it because keeping + # track of the used names can be costly when nodes can be removed from the graph: + # How do we know if a name is no longer used? We cannot reserve unused names + # because users may want to use them. + self._node_names.add(node.name) diff --git a/onnxscript/ir/_name_authority_test.py b/onnxscript/ir/_name_authority_test.py new file mode 100644 index 000000000..4bf7c6c7d --- /dev/null +++ b/onnxscript/ir/_name_authority_test.py @@ -0,0 +1,26 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import unittest + +from onnxscript import ir +from onnxscript.ir import _name_authority + + +class NameAuthorityTest(unittest.TestCase): + def test_register_or_name_value(self): + name_authority = _name_authority.NameAuthority() + value = ir.Value() + name_authority.register_or_name_value(value) + self.assertEqual(value.name, "val_0") + + def test_register_or_name_node(self): + name_authority = _name_authority.NameAuthority() + node = ir.Node("", "Test", []) + name_authority.register_or_name_node(node) + self.assertEqual(node.name, "node_Test_0") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 05093491d..d097e9a43 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -590,7 +590,7 @@ def deserialize_value_info_proto( proto: onnx.ValueInfoProto, value: _core.Value | None ) -> _core.Value: if value is None: - value = _core.Value(None, index=None, name=proto.name) + value = _core.Value(name=proto.name) value.shape = deserialize_type_proto_for_shape(proto.type) value.type = deserialize_type_proto_for_type(proto.type) metadata_props = deserialize_metadata_props(proto.metadata_props) @@ -847,7 +847,7 @@ def _deserialize_node( "the node is referencing a value that is not in the current graph, " "it is impossible to create it in the correct scope.", ) - value = _core.Value(None, index=None, name=input_name) + value = _core.Value(name=input_name) # Fill in shape/type information if they exist if input_name in value_info: deserialize_value_info_proto(value_info[input_name], value) @@ -862,7 +862,7 @@ def _deserialize_node( for output_name in proto.output: if output_name == "": # Empty output - node_outputs.append(_core.Value(None, index=None, name="")) + node_outputs.append(_core.Value(name="")) continue # 1. When the graph is unsorted, we may be able to find the output already created @@ -880,7 +880,7 @@ def _deserialize_node( else: # 2. Common scenario: the graph is sorted and this is the first time we see the output. # Create the value and add it to the current scope. - value = _core.Value(None, index=None, name=output_name) + value = _core.Value(name=output_name) current_scope[output_name] = value # Fill in shape/type information if they exist if output_name in value_info: diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index b2f8ec07b..d06bf06f8 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -34,9 +34,7 @@ def test_from_proto(self, _: str, proto): ("graph", ir.Graph([], [], nodes=[])), ( "node", - ir.Node( - "", "Op", inputs=[], outputs=[ir.Value(None, index=None, name="value")] - ), + ir.Node("", "Op", inputs=[], outputs=[ir.Value(name="value")]), ), ( "tensor", @@ -44,7 +42,7 @@ def test_from_proto(self, _: str, proto): onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0]) ), ), - ("value", ir.Value(None, index=None, name="value")), + ("value", ir.Value(name="value")), ("type", ir.SequenceType(ir.OptionalType(ir.TensorType(ir.DataType.COMPLEX128)))), ("attribute", ir.Attr("attribute", ir.AttributeType.FLOAT, 1)), ("ref_attribute", ir.RefAttr("ref_attr", "attr", ir.AttributeType.FLOAT)), diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py index e4afb432d..16d8838f7 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py @@ -3,7 +3,6 @@ from onnxscript import ir logger = logging.getLogger(__name__) -CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16" def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None: @@ -61,9 +60,6 @@ def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> No ) cast.outputs[0].dtype = ir.DataType.FLOAT16 cast.outputs[0].shape = node.outputs[index].shape - # To prevent naming conflicts, we need to append suffix to the output name of the cast node - # TODO: Remove this after naming authority covers this case - cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX # type: ignore[operator] node.append(cast) assert node.graph is not None, "Node graph should not be None"