Skip to content

Commit

Permalink
[IR] Support more attributes in convenience methods (#1506)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1509
* #1508
* #1507
* __->__ #1506

Support graph and type attributes
  • Loading branch information
justinchuby authored May 7, 2024
1 parent 280fb39 commit c8cd684
Showing 1 changed file with 68 additions and 1 deletion.
69 changes: 68 additions & 1 deletion onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
onnx.TensorProto,
_core.Attr,
_core.RefAttr,
_protocols.GraphProtocol,
Sequence[_protocols.GraphProtocol],
_protocols.TypeProtocol,
Sequence[_protocols.TypeProtocol],
None,
]

Expand All @@ -56,6 +60,30 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
return _enums.AttributeType.TENSOR
if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
return _enums.AttributeType.GRAPH
if isinstance(attr, Sequence) and all(
isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
):
return _enums.AttributeType.GRAPHS
if isinstance(
attr,
(_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
):
return _enums.AttributeType.TYPE_PROTO
if isinstance(attr, Sequence) and all(
isinstance(
x,
(
_core.TensorType,
_core.SequenceType,
_core.OptionalType,
_protocols.TypeProtocol,
),
)
for x in attr
):
return _enums.AttributeType.TYPE_PROTOS
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


Expand Down Expand Up @@ -118,6 +146,14 @@ def convert_attribute(
return _core.AttrTensor(name, attr)
if isinstance(attr, onnx.TensorProto):
return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
if attr_type == _enums.AttributeType.GRAPH:
return _core.AttrGraph(name, attr) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.GRAPHS:
return _core.AttrGraphs(name, attr) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.TYPE_PROTO:
return _core.AttrTypeProto(name, attr) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.TYPE_PROTOS:
return _core.AttrTypeProtos(name, attr) # type: ignore[arg-type]
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


Expand Down Expand Up @@ -148,9 +184,40 @@ def convert_attributes(
... float_data=[1.0, 2.0, 3.0],
... name="proto",
... ),
... "graph": ir.Graph([], [], nodes=[], name="graph0"),
... "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
... "type_proto": ir.TensorType(ir.DataType.FLOAT),
... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
... }
>>> convert_attributes(attrs)
[AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor<FLOAT,[3]>(name='proto'))]
[AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor<FLOAT,[3]>(name='proto')), AttrInt64s('graph', Graph(
name='graph0',
inputs=(
<BLANKLINE>
),
outputs=(
<BLANKLINE>
),
len()=0
)), AttrGraphs('graphs', [Graph(
name='graph1',
inputs=(
<BLANKLINE>
),
outputs=(
<BLANKLINE>
),
len()=0
), Graph(
name='graph2',
inputs=(
<BLANKLINE>
),
outputs=(
<BLANKLINE>
),
len()=0
)]), AttrTypeProto('type_proto', Tensor(FLOAT)), AttrTypeProtos('type_protos', [Tensor(FLOAT), Tensor(FLOAT)])]
Args:
attrs: A dictionary of {<attribute name>: <python objects>} to convert.
Expand Down

0 comments on commit c8cd684

Please sign in to comment.