Skip to content

Commit

Permalink
[IR] Improve name authority to generate unique names (#1537)
Browse files Browse the repository at this point in the history
- Store all names from the graph for generating unique names for new
values.
- Also allow values to be initialized with no arguments.

Fix #1535
  • Loading branch information
justinchuby authored May 15, 2024
1 parent ebee154 commit 2b6dc27
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 41 deletions.
41 changes: 30 additions & 11 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,9 +1357,9 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

def __init__(
self,
producer: Node | None,
producer: Node | None = None,
*,
index: int | None,
index: int | None = None,
name: str | None = None,
shape: Shape | None = None,
type: _protocols.TypeProtocol | None = None,
Expand All @@ -1368,7 +1368,18 @@ def __init__(
| Sequence[_protocols.TensorProtocol]
| None = None,
) -> None:
# producer is None when the value is an input or an initializer
"""Initialize a value.
Args:
producer: The node that produces the value.
It can be ``None`` when the value is initialized before its producer.
index: The index of the output of the defining node.
name: The name of the value.
shape: The shape of the value.
type: The type of the value.
doc_string: The documentation string.
const_value: The constant tensor if the value is constant.
"""
self._producer: Node | None = producer
self._index: int | None = index
self._metadata: _metadata.MetadataStore | None = None
Expand Down Expand Up @@ -1406,7 +1417,11 @@ def __str__(self) -> str:
return f"%{_quoted(value_name)}<{type_text},{shape_text}>"

def producer(self) -> Node | None:
"""The node that produces this value."""
"""The node that produces this value.
When producer is ``None``, the value does not belong to a node, and is
typically a graph input or an initializer.
"""
return self._producer

def index(self) -> int | None:
Expand Down Expand Up @@ -1550,9 +1565,7 @@ def __init__(
type: _protocols.TypeProtocol | None = None,
doc_string: str | None = None,
) -> None:
super().__init__(
None, index=None, name=name, shape=shape, type=type, doc_string=doc_string
)
super().__init__(name=name, shape=shape, type=type, doc_string=doc_string)


def _check_node_safe_to_remove(
Expand Down Expand Up @@ -1712,11 +1725,9 @@ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()."
)
# Give the node and its output values names if they don't have one
if node.name is None:
self._name_authority.name_node(node)
self._name_authority.register_or_name_node(node)
for value in node._outputs: # pylint: disable=protected-access
if value.name is None:
self._name_authority.name_value(value)
self._name_authority.register_or_name_value(value)
node.graph = self
return node

Expand Down Expand Up @@ -1766,6 +1777,8 @@ def num_nodes(self) -> int:
def append(self, node: Node, /) -> None:
"""Append a node to the graph in O(1) time.
Unique names will be assigned to the node and its values if any name is ``None``.
Args:
node: The node to append.
Expand All @@ -1778,6 +1791,8 @@ def append(self, node: Node, /) -> None:
def extend(self, nodes: Iterable[Node], /) -> None:
"""Extend the graph with the given nodes in O(#new_nodes) time.
Unique names will be assigned to the node and its values if any name is ``None``.
Args:
nodes: The nodes to extend the graph with.
Expand Down Expand Up @@ -1830,6 +1845,8 @@ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
"""Insert new nodes after the given node in O(#new_nodes) time.
Unique names will be assigned to the node and its values if any name is ``None``.
Args:
node: The node to insert after.
new_nodes: The new nodes to insert.
Expand All @@ -1845,6 +1862,8 @@ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
"""Insert new nodes before the given node in O(#new_nodes) time.
Unique names will be assigned to the node and its values if any name is ``None``.
Args:
node: The node to insert before.
new_nodes: The new nodes to insert.
Expand Down
12 changes: 4 additions & 8 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,10 @@ def test_set_denotation_is_still_possible_when_shape_is_frozen(self):

class ValueTest(unittest.TestCase):
def test_initialize(self):
_ = _core.Value(None, index=0)
_ = _core.Value()

def test_meta(self):
value = _core.Value(None, index=0)
value = _core.Value()
value.meta["test"] = 1
self.assertEqual(value.meta["test"], 1)
value.metadata_props["test"] = "any string"
Expand All @@ -568,8 +568,8 @@ def test_meta(self):

class NodeTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value(None, index=None)
self.v1 = _core.Value(None, index=None)
self.v0 = _core.Value()
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)

def test_init_with_values(self):
Expand All @@ -581,15 +581,11 @@ def test_init_with_values(self):

def test_init_with_preinitialized_outputs(self):
out_1 = _core.Value(
None,
index=None,
name="out_1",
shape=_core.Shape([1]),
type=_core.TensorType(ir.DataType.BFLOAT16),
)
out_2 = _core.Value(
None,
index=None,
name="out_2",
shape=_core.Shape([2]),
type=_core.TensorType(ir.DataType.INT4),
Expand Down
59 changes: 49 additions & 10 deletions onnxscript/ir/_name_authority.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,59 @@ class NameAuthority:
``node_{op_type}_{node_counter}`` for nodes. The counter is incremented each time
a new value or node is named.
The class does not keep track of the names it has given, so it is possible to
generate names that conflicts with existing names. It is the responsibility of the
user to ensure that the names are unique (typically by running a name-fixing pass
on the graph).
This class keeps track of the names it has generated and existing names
in the graph to prevent producing duplicated names.
.. note::
Once a name is tracked, it will not be made available even if the node/value
is removed from the graph. It is possible to improve this behavior by keeping
track of the names that are no longer used, but it is not implemented yet.
However, if a value/node is already named when added to the graph,
the name authority will not change its name.
It is the responsibility of the user to ensure that the names are unique
(typically by running a name-fixing pass on the graph).
TODO(justinchuby): Describe the pass when we have a reference implementation.
"""

def __init__(self):
self._value_counter = 0
self._node_counter = 0
self._value_names: set[str] = set()
self._node_names: set[str] = set()

def _unique_value_name(self) -> str:
"""Generate a unique name for a value."""
while True:
name = f"val_{self._value_counter}"
self._value_counter += 1
if name not in self._value_names:
return name

def _unique_node_name(self, op_type: str) -> str:
"""Generate a unique name for a node."""
while True:
name = f"node_{op_type}_{self._node_counter}"
self._node_counter += 1
if name not in self._node_names:
return name

def name_value(self, value: _core.Value) -> None:
    """Name ``value`` as ``val_{n}`` and advance the value counter by one.

    The name is built from the current counter only; no check against
    existing names is made here.
    """
    suffix = self._value_counter
    value.name = f"val_{suffix}"
    self._value_counter = suffix + 1
def register_or_name_value(self, value: _core.Value) -> None:
    """Record the name of ``value``, generating one first if it has none.

    A value whose name is ``None`` receives a name from
    ``self._unique_value_name()``. A value that already has a name keeps
    it unchanged. Either way the resulting name is added to
    ``self._value_names``.
    """
    # TODO(justinchuby): Record names of the initializers and graph inputs
    if value.name is not None:
        # An existing name is never changed: tracking which names are still
        # in use is costly when nodes can be removed from the graph (how do
        # we know a name is no longer used?), and unused names cannot be
        # reserved because users may want to use them.
        self._value_names.add(value.name)
        return
    value.name = self._unique_value_name()
    self._value_names.add(value.name)

def name_node(self, node: _core.Node) -> None:
    """Name ``node`` as ``node_{op_type}_{n}`` and advance the node counter by one.

    The name is built from ``node.op_type`` and the current counter only;
    no check against existing names is made here.
    """
    suffix = self._node_counter
    node.name = f"node_{node.op_type}_{suffix}"
    self._node_counter = suffix + 1
def register_or_name_node(self, node: _core.Node) -> None:
    """Record the name of ``node``, generating one first if it has none.

    A node whose name is ``None`` receives a name from
    ``self._unique_node_name(node.op_type)``. A node that already has a
    name keeps it unchanged. Either way the resulting name is added to
    ``self._node_names``.
    """
    if node.name is not None:
        # An existing name is never changed: tracking which names are still
        # in use is costly when nodes can be removed from the graph (how do
        # we know a name is no longer used?), and unused names cannot be
        # reserved because users may want to use them.
        self._node_names.add(node.name)
        return
    node.name = self._unique_node_name(node.op_type)
    self._node_names.add(node.name)
26 changes: 26 additions & 0 deletions onnxscript/ir/_name_authority_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import unittest

from onnxscript import ir
from onnxscript.ir import _name_authority


class NameAuthorityTest(unittest.TestCase):
    """Checks the first names produced by a freshly created ``NameAuthority``."""

    def test_register_or_name_value(self):
        """A value created with no arguments is named ``val_0``."""
        unnamed_value = ir.Value()
        authority = _name_authority.NameAuthority()
        authority.register_or_name_value(unnamed_value)
        self.assertEqual(unnamed_value.name, "val_0")

    def test_register_or_name_node(self):
        """A node built as ``Node("", "Test", [])`` is named ``node_Test_0``."""
        test_node = ir.Node("", "Test", [])
        authority = _name_authority.NameAuthority()
        authority.register_or_name_node(test_node)
        self.assertEqual(test_node.name, "node_Test_0")


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def deserialize_value_info_proto(
proto: onnx.ValueInfoProto, value: _core.Value | None
) -> _core.Value:
if value is None:
value = _core.Value(None, index=None, name=proto.name)
value = _core.Value(name=proto.name)
value.shape = deserialize_type_proto_for_shape(proto.type)
value.type = deserialize_type_proto_for_type(proto.type)
metadata_props = deserialize_metadata_props(proto.metadata_props)
Expand Down Expand Up @@ -847,7 +847,7 @@ def _deserialize_node(
"the node is referencing a value that is not in the current graph, "
"it is impossible to create it in the correct scope.",
)
value = _core.Value(None, index=None, name=input_name)
value = _core.Value(name=input_name)
# Fill in shape/type information if they exist
if input_name in value_info:
deserialize_value_info_proto(value_info[input_name], value)
Expand All @@ -862,7 +862,7 @@ def _deserialize_node(
for output_name in proto.output:
if output_name == "":
# Empty output
node_outputs.append(_core.Value(None, index=None, name=""))
node_outputs.append(_core.Value(name=""))
continue

# 1. When the graph is unsorted, we may be able to find the output already created
Expand All @@ -880,7 +880,7 @@ def _deserialize_node(
else:
# 2. Common scenario: the graph is sorted and this is the first time we see the output.
# Create the value and add it to the current scope.
value = _core.Value(None, index=None, name=output_name)
value = _core.Value(name=output_name)
current_scope[output_name] = value
# Fill in shape/type information if they exist
if output_name in value_info:
Expand Down
6 changes: 2 additions & 4 deletions onnxscript/ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ def test_from_proto(self, _: str, proto):
("graph", ir.Graph([], [], nodes=[])),
(
"node",
ir.Node(
"", "Op", inputs=[], outputs=[ir.Value(None, index=None, name="value")]
),
ir.Node("", "Op", inputs=[], outputs=[ir.Value(name="value")]),
),
(
"tensor",
serde.TensorProtoTensor(
onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0])
),
),
("value", ir.Value(None, index=None, name="value")),
("value", ir.Value(name="value")),
("type", ir.SequenceType(ir.OptionalType(ir.TensorType(ir.DataType.COMPLEX128)))),
("attribute", ir.Attr("attribute", ir.AttributeType.FLOAT, 1)),
("ref_attribute", ir.RefAttr("ref_attr", "attr", ir.AttributeType.FLOAT)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from onnxscript import ir

logger = logging.getLogger(__name__)
CREATED_CAST_BFLOAT16_NAME_SUFFIX = "_cast_bfloat16"


def _convert_inputs_from_bfloat16_to_float16(value: ir.Input) -> None:
Expand Down Expand Up @@ -61,9 +60,6 @@ def _insert_cast_nodes_for_bfloat16_to_float16_to_outputs(value: ir.Value) -> No
)
cast.outputs[0].dtype = ir.DataType.FLOAT16
cast.outputs[0].shape = node.outputs[index].shape
# To prevent naming conflicts, we need to append suffix to the output name of the cast node
# TODO: Remove this after naming authority covers this case
cast.outputs[0].name = node.outputs[index].name + CREATED_CAST_BFLOAT16_NAME_SUFFIX # type: ignore[operator]
node.append(cast)

assert node.graph is not None, "Node graph should not be None"
Expand Down

0 comments on commit 2b6dc27

Please sign in to comment.