From c8cd68453d8052eb44f17bc201523fa6de8d071c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 May 2024 11:19:15 -0700 Subject: [PATCH] [IR] Support more attributes in convenience methods (#1506) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #1509 * #1508 * #1507 * __->__ #1506 Support graph and type attributes --- onnxscript/ir/_convenience.py | 69 ++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_convenience.py b/onnxscript/ir/_convenience.py index ccc10ef25..7eba1cb28 100644 --- a/onnxscript/ir/_convenience.py +++ b/onnxscript/ir/_convenience.py @@ -33,6 +33,10 @@ onnx.TensorProto, _core.Attr, _core.RefAttr, + _protocols.GraphProtocol, + Sequence[_protocols.GraphProtocol], + _protocols.TypeProtocol, + Sequence[_protocols.TypeProtocol], None, ] @@ -56,6 +60,30 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType: if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)): # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower return _enums.AttributeType.TENSOR + if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)): + return _enums.AttributeType.GRAPH + if isinstance(attr, Sequence) and all( + isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr + ): + return _enums.AttributeType.GRAPHS + if isinstance( + attr, + (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol), + ): + return _enums.AttributeType.TYPE_PROTO + if isinstance(attr, Sequence) and all( + isinstance( + x, + ( + _core.TensorType, + _core.SequenceType, + _core.OptionalType, + _protocols.TypeProtocol, + ), + ) + for x in attr + ): + return _enums.AttributeType.TYPE_PROTOS raise TypeError(f"Unsupported attribute type: '{type(attr)}'") @@ -118,6 +146,14 @@ def convert_attribute( return _core.AttrTensor(name, attr) if isinstance(attr, onnx.TensorProto): return _core.AttrTensor(name, serde.TensorProtoTensor(attr)) + if attr_type == _enums.AttributeType.GRAPH: + return _core.AttrGraph(name, attr) # type: ignore[arg-type] + if attr_type == _enums.AttributeType.GRAPHS: + return _core.AttrGraphs(name, attr) # type: ignore[arg-type] + if attr_type == _enums.AttributeType.TYPE_PROTO: + return _core.AttrTypeProto(name, attr) # type: ignore[arg-type] + if attr_type == _enums.AttributeType.TYPE_PROTOS: + return _core.AttrTypeProtos(name, attr) # type: ignore[arg-type] raise TypeError(f"Unsupported attribute type: '{type(attr)}'") @@ -148,9 +184,40 @@ def convert_attributes( ... float_data=[1.0, 2.0, 3.0], ... name="proto", ... ), + ... "graph": ir.Graph([], [], nodes=[], name="graph0"), + ... "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")], + ... "type_proto": ir.TensorType(ir.DataType.FLOAT), + ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)], ... } >>> convert_attributes(attrs) - [AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor(name='proto'))] + [AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor(name='proto')), AttrInt64s('graph', Graph( + name='graph0', + inputs=( + + ), + outputs=( + + ), + len()=0 + )), AttrGraphs('graphs', [Graph( + name='graph1', + inputs=( + + ), + outputs=( + + ), + len()=0 + ), Graph( + name='graph2', + inputs=( + + ), + outputs=( + + ), + len()=0 + )]), AttrTypeProto('type_proto', Tensor(FLOAT)), AttrTypeProtos('type_protos', [Tensor(FLOAT), Tensor(FLOAT)])] Args: attrs: A dictionary of {: } to convert.