Skip to content

Commit

Permalink
[IR] Support deserializing unsorted graphs (#1509)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1510
* #1508
* __->__ #1509

When a graph is unsorted, we initialize the node input value first and
add it to the value pull. Producer nodes of this value will find it from
the pool when the node is initialized later, with a warning message.

```
Input 'val_1' of node 'node_1(::Node1:)' not found in any scope. The graph may be unsorted. Creating a new input (current depth: 1) .
```

We needed this feature because we also need to be able to convert
invalid models into IR and be also to fix them.

Fix #1426
  • Loading branch information
justinchuby authored May 7, 2024
1 parent 1e4b585 commit c2d1de1
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 36 deletions.
21 changes: 12 additions & 9 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,15 +1317,16 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
"""

__slots__ = (
"_producer",
"_const_value",
"_index",
"_metadata",
"_metadata_props",
"_metadata",
"_name",
"_producer",
"_shape",
"_type",
"_const_value",
"_uses",
"doc_string",
)

def __init__(
Expand All @@ -1336,6 +1337,7 @@ def __init__(
name: str | None = None,
shape: Shape | None = None,
type: _protocols.TypeProtocol | None = None,
doc_string: str | None = None,
const_value: _protocols.TensorProtocol
| Sequence[_protocols.TensorProtocol]
| None = None,
Expand All @@ -1356,19 +1358,20 @@ def __init__(
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
producer = self.producer()
producer_text = (
producer.name or "anonymous_node:" + str(id(producer))
producer.name is not None or "anonymous_node:" + str(id(producer))
if producer is not None
else None
)
return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})"

def __str__(self) -> str:
value_name = self.name if self.name else "anonymous:" + str(id(self))
value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
shape_text = str(self.shape) if self.shape is not None else "?"
type_text = str(self.type) if self.type is not None else "?"

Expand Down Expand Up @@ -1519,11 +1522,11 @@ def __init__(
name: str | None = None,
shape: Shape | None = None,
type: _protocols.TypeProtocol | None = None,
doc_string: str | None = None,
) -> None:
super().__init__(None, index=None)
self._name = name
self._shape = shape
self._type = type
super().__init__(
None, index=None, name=name, shape=shape, type=type, doc_string=doc_string
)


def _check_node_safe_to_remove(
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/ir/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,15 @@ class ValueProtocol(Protocol):
type: The type of the value.
metadata_props: Metadata that will be serialized to the ONNX file.
meta: Metadata store for graph transform passes.
doc_string: Documentation string.
"""

name: str
shape: ShapeProtocol | None
type: TypeProtocol | None
metadata_props: MutableMapping[str, str]
meta: MutableMapping[str, Any]
doc_string: str | None

def producer(self) -> NodeProtocol | None:
"""The node that produces this value."""
Expand Down
105 changes: 78 additions & 27 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ def _deserialize_graph(
value_info = {info.name: info for info in proto.value_info}

# Deserialize nodes with all known values
# TODO(justinchuby): Handle unsorted nodes
nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]

# Fill in values for graph outputs
Expand Down Expand Up @@ -514,13 +513,13 @@ def deserialize_value_info_proto(
proto: onnx.ValueInfoProto, value: _core.Value | None
) -> _core.Value:
if value is None:
value = _core.Value(None, index=None)
value.name = proto.name
value = _core.Value(None, index=None, name=proto.name)
value.shape = deserialize_type_proto_for_shape(proto.type)
value.type = deserialize_type_proto_for_type(proto.type)
metadata_props = deserialize_metadata_props(proto.metadata_props)
if metadata_props is not None:
value.metadata_props.update(metadata_props)
value.doc_string = _get_field(proto, "doc_string")
return value


Expand Down Expand Up @@ -735,50 +734,100 @@ def _deserialize_node(
value_info: dict[str, onnx.ValueInfoProto],
) -> _core.Node:
node_inputs: list[_core.Value | None] = []
for name in proto.input:
if name == "":
for input_name in proto.input:
if input_name == "":
# Empty input
node_inputs.append(None)
continue

# Find the input in all value scopes
found = False
for values in reversed(scoped_values):
if name not in values:
if input_name not in values:
continue
node_inputs.append(values[name])
node_inputs.append(values[input_name])
found = True
del values # Remove the reference so it is not used by mistake
break
if not found:
raise ValueError(
f"Input '{name}' of node '{proto.name}({proto.domain}::{proto.op_type}:{getattr(proto, 'overload', '')})' not found in any scope"
f" (current depth: {len(scoped_values)})"
# If the input is not found, we know the graph may be unsorted and
# the input may be a supposed-to-be initializer or an output of a node that comes later.
# Here we create the value with the name and add it to the current scope.
# Nodes need to check the value pool for potentially initialized outputs
logger.warning(
"Input '%s' of node '%s(%s::%s:%s)' not found in any scope. "
"The graph may be unsorted. Creating a new input (current depth: %s) .",
input_name,
proto.name,
proto.domain,
proto.op_type,
getattr(proto, "overload", ""),
len(scoped_values),
)
if len(scoped_values) > 1:
logger.warning(
"Caveat: The value is created in the subgraph. If "
"the node is referencing a value that is not in the current graph, "
"it is impossible to create it in the correct scope.",
)
value = _core.Value(None, index=None, name=input_name)
# Fill in shape/type information if they exist
if input_name in value_info:
deserialize_value_info_proto(value_info[input_name], value)
node_inputs.append(value)
# We can only create the value in the current scope. If the subgraph is
# referencing a value that is not in the current scope, it is impossible
# to create it in the correct scope.
scoped_values[-1][input_name] = value

# Build the output values for the node.
node_outputs: list[_core.Value] = []
for output_name in proto.output:
if output_name == "":
# Empty output
node_outputs.append(_core.Value(None, index=None, name=""))
continue

# 1. When the graph is unsorted, we may be able to find the output already created
# as an input to some other nodes in the current scope.
# Note that a value is always owned by the producing node. Even though a value
# can be created when parsing inputs of other nodes, the new node created here
# that produces the value will assume ownership. It is then impossible to transfer
# the ownership to any other node.

# The output can only be found in the current scope. It is impossible for
# a node to produce an output that is not in its own scope.
current_scope = scoped_values[-1]
if output_name in current_scope:
value = current_scope[output_name]
else:
# 2. Common scenario: the graph is sorted and this is the first time we see the output.
# Create the value and add it to the current scope.
value = _core.Value(None, index=None, name=output_name)
current_scope[output_name] = value
# Fill in shape/type information if they exist
if output_name in value_info:
deserialize_value_info_proto(value_info[output_name], value)
else:
logger.debug(
"ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
output_name,
proto.name,
proto.op_type,
)
node = _core.Node(
node_outputs.append(value)
return _core.Node(
proto.domain,
proto.op_type,
node_inputs,
[_deserialize_attribute(a, scoped_values) for a in proto.attribute],
overload=getattr(proto, "overload", ""),
num_outputs=len(proto.output),
outputs=node_outputs,
name=proto.name,
doc_string=_get_field(proto, "doc_string"),
metadata_props=deserialize_metadata_props(proto.metadata_props),
)

for output, value in zip(proto.output, node.outputs):
value.name = output
if output in value_info:
deserialize_value_info_proto(value_info[output], value)
else:
logger.debug(
"ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
output,
proto.name,
proto.op_type,
)
scoped_values[-1][output] = value

return node


# Serialization

Expand Down Expand Up @@ -1248,6 +1297,8 @@ def serialize_value_into(
# Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
if from_.shape is not None:
serialize_shape_into(value_info_proto.type, from_.shape)
if from_.doc_string:
value_info_proto.doc_string = from_.doc_string


def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
Expand Down
34 changes: 34 additions & 0 deletions onnxscript/ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import onnx
import parameterized

from onnxscript import ir
from onnxscript.ir import serde


Expand Down Expand Up @@ -175,3 +176,36 @@ def test_tensor_proto_tensor_empty_tensor(self):
)
array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data)
np.testing.assert_array_equal(array_from_raw_data, expected_array)


class DeserializeGraphTest(unittest.TestCase):
def test_deserialize_graph_handles_unsorted_graph(self):
node_0 = ir.Node(
"",
"Op_0",
inputs=[ir.Input("input_0"), ir.Input("input_1")],
num_outputs=2,
name="node_0",
)
node_1 = ir.Node(
"",
"Op_1",
inputs=[node_0.outputs[0]],
num_outputs=1,
name="node_1",
)
graph = ir.Graph(
inputs=node_0.inputs, # type: ignore
outputs=[node_1.outputs[0]],
# Unsorted nodes
nodes=[node_1, node_0],
name="test_graph",
)
graph_proto = serde.serialize_graph(graph)
deserialized_graph = serde.deserialize_graph(graph_proto)
self.assertEqual(deserialized_graph[0].op_type, "Op_1")
self.assertEqual(deserialized_graph[1].op_type, "Op_0")


if __name__ == "__main__":
unittest.main()

0 comments on commit c2d1de1

Please sign in to comment.