Skip to content

Commit

Permalink
[IR] Implement to_proto and from_proto convenience functions (#1508)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #1508
  • Loading branch information
justinchuby authored May 10, 2024
1 parent b4dd777 commit fefea96
Show file tree
Hide file tree
Showing 7 changed files with 1,611 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# -- General configuration ---------------------------------------------------

extensions = [
"myst_parser",
"myst_nb",
"sphinx_copybutton",
"sphinx_exec_code",
"sphinx_gallery.gen_gallery",
Expand Down
1,487 changes: 1,487 additions & 0 deletions docs/intermediate_representation/getting_started.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/intermediate_representation/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```{toctree}
:maxdepth: 1
getting_started
tensors
ir_api
```
5 changes: 4 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
"OperatorIdentifier",
# Protobuf compatible types
"TensorProtoTensor",
# Conversion functions
"from_proto",
"to_proto",
]

from onnxscript.ir import serde
Expand Down Expand Up @@ -126,4 +129,4 @@
TypeProtocol,
ValueProtocol,
)
from onnxscript.ir.serde import TensorProtoTensor
from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto
71 changes: 71 additions & 0 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# Tensors
"TensorProtoTensor",
# Deserialization
"from_proto",
"deserialize_attribute",
"deserialize_function",
"deserialize_graph",
Expand All @@ -30,6 +31,7 @@
"deserialize_type_proto_for_type",
"deserialize_value_info_proto",
# Serialization
"to_proto",
"serialize_attribute_into",
"serialize_attribute",
"serialize_dimension_into",
Expand Down Expand Up @@ -89,6 +91,75 @@ def _unflatten_complex(
return array[::2] + 1j * array[1::2]


def from_proto(
proto: onnx.ModelProto
| onnx.GraphProto
| onnx.NodeProto
| onnx.TensorProto
| onnx.AttributeProto
| onnx.ValueInfoProto
| onnx.TypeProto,
) -> Any:
"""Deserialize an ONNX proto message to an IR object."""
if isinstance(proto, onnx.ModelProto):
return deserialize_model(proto)
if isinstance(proto, onnx.GraphProto):
return deserialize_graph(proto)
if isinstance(proto, onnx.NodeProto):
return deserialize_node(proto)
if isinstance(proto, onnx.TensorProto):
return deserialize_tensor(proto)
if isinstance(proto, onnx.AttributeProto):
return deserialize_attribute(proto)
if isinstance(proto, onnx.ValueInfoProto):
return deserialize_value_info_proto(proto, None)
if isinstance(proto, onnx.TypeProto):
return _core.TypeAndShape(
deserialize_type_proto_for_type(proto),
deserialize_type_proto_for_shape(proto),
)
raise NotImplementedError(
f"Deserialization of {type(proto)} in from_proto is not implemented. "
"Use a specific ir.serde.deserialize* function instead."
)


def to_proto(
ir_object: _protocols.ModelProtocol
| _protocols.GraphProtocol
| _protocols.NodeProtocol
| _protocols.ValueProtocol
| _protocols.AttributeProtocol
| _protocols.ReferenceAttributeProtocol
| _protocols.TensorProtocol
| onnx.TypeProto
| _protocols.GraphViewProtocol,
) -> Any:
"""Serialize an IR object to a proto."""
if isinstance(ir_object, _protocols.ModelProtocol):
return serialize_model(ir_object)
if isinstance(ir_object, _protocols.GraphProtocol):
return serialize_graph(ir_object)
if isinstance(ir_object, _protocols.NodeProtocol):
return serialize_node(ir_object)
if isinstance(ir_object, _protocols.TensorProtocol):
return serialize_tensor(ir_object)
if isinstance(ir_object, _protocols.ValueProtocol):
return serialize_value(ir_object)
if isinstance(ir_object, _protocols.AttributeProtocol):
return serialize_attribute(ir_object)
if isinstance(ir_object, _protocols.ReferenceAttributeProtocol):
return serialize_reference_attribute_into(onnx.AttributeProto(), ir_object)
if isinstance(ir_object, _protocols.TypeProtocol):
return serialize_type_into(onnx.TypeProto(), ir_object)
if isinstance(ir_object, _protocols.GraphViewProtocol):
return serialize_graph(ir_object)
raise NotImplementedError(
f"Serialization of {type(ir_object)} in to_proto is not implemented. "
"Use a specific ir.serde.serialize* function instead."
)


class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
"""A tensor initialized from a tensor proto."""

Expand Down
45 changes: 45 additions & 0 deletions onnxscript/ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,51 @@
from onnxscript.ir import serde


class ConvenienceFunctionsTest(unittest.TestCase):
@parameterized.parameterized.expand(
[
("model", onnx.ModelProto()),
("graph", onnx.GraphProto()),
("node", onnx.NodeProto()),
(
"tensor",
onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0]),
),
("value_info", onnx.ValueInfoProto()),
("type", onnx.TypeProto()),
("attribute", onnx.AttributeProto()),
]
)
def test_from_proto(self, _: str, proto):
serde.from_proto(proto)

@parameterized.parameterized.expand(
[
("model", ir.Model(ir.Graph([], [], nodes=[]), ir_version=1)),
("graph", ir.Graph([], [], nodes=[])),
(
"node",
ir.Node(
"", "Op", inputs=[], outputs=[ir.Value(None, index=None, name="value")]
),
),
(
"tensor",
serde.TensorProtoTensor(
onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0])
),
),
("value", ir.Value(None, index=None, name="value")),
("type", ir.SequenceType(ir.OptionalType(ir.TensorType(ir.DataType.COMPLEX128)))),
("attribute", ir.Attr("attribute", ir.AttributeType.FLOAT, 1)),
("ref_attribute", ir.RefAttr("ref_attr", "attr", ir.AttributeType.FLOAT)),
("graph_view", ir.GraphView([], [], nodes=[])),
]
)
def test_to_proto(self, _: str, ir_object):
serde.to_proto(ir_object)


class TensorProtoTensorTest(unittest.TestCase):
@parameterized.parameterized.expand(
[
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ sphinx-copybutton
sphinx-exec-code
sphinx-gallery
sphinx>=6
myst_nb
chardet

# Torch lib
beartype!=0.16.0
Expand Down

0 comments on commit fefea96

Please sign in to comment.