Skip to content

Commit

Permalink
[IR] Implement convert_attributes (#1417)
Browse files Browse the repository at this point in the history
Implement `convert_attribute` and `convert_attributes` as convenience
functions for the IR.

Also updates repr for Attr to make it more succinct.
  • Loading branch information
justinchuby authored Apr 22, 2024
1 parent 669a37e commit 5713872
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 69 deletions.
167 changes: 147 additions & 20 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,157 @@

from __future__ import annotations

from typing import Any, Mapping, Sequence
__all__ = [
"convert_attribute",
"convert_attributes",
"replace_all_uses_with",
]

from typing import Mapping, Sequence, Union

import onnx

from onnxscript.ir import _core, _enums, _protocols, serde

SupportedAttrTypes = Union[
str,
int,
float,
Sequence[int],
Sequence[float],
Sequence[str],
_protocols.TensorProtocol, # This includes all in-memory tensor types
onnx.TensorProto,
_core.Attr,
_core.RefAttr,
None,
]


def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
"""Infer the attribute type based on the type of the Python object."""
if isinstance(attr, int):
return _enums.AttributeType.INT
if isinstance(attr, float):
return _enums.AttributeType.FLOAT
if isinstance(attr, str):
return _enums.AttributeType.STRING
if isinstance(attr, (_core.Attr, _core.RefAttr)):
return attr.type
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
return _enums.AttributeType.INTS
if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
return _enums.AttributeType.FLOATS
if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
return _enums.AttributeType.STRINGS
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
return _enums.AttributeType.TENSOR
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


def convert_attribute(
name: str,
attr: SupportedAttrTypes,
attr_type: _enums.AttributeType | None = None,
) -> _core.Attr | _core.RefAttr:
"""Convert a Python object to a _core.Attr object.
This method is useful when constructing nodes with attributes. It infers the
attribute type based on the type of the Python value.
from onnxscript.ir import _core, _protocols
Args:
name: The name of the attribute.
attr: The value of the attribute.
attr_type: The type of the attribute. This is required when attr is None.
When provided, it overrides the inferred type.
Returns:
A ``Attr`` object.
Raises:
ValueError: If :param:`attr` is ``None`` and :param:`attr_type` is not provided.
TypeError: If the type of the attribute is not supported.
"""
if attr is None:
if attr_type is None:
raise ValueError("attr_type must be provided when attr is None")
return _core.Attr(name, attr_type, None)

if isinstance(attr, (_core.Attr, _core.RefAttr)):
if attr.name != name:
raise ValueError(
f"Attribute name '{attr.name}' does not match provided name '{name}'"
)
if attr_type is not None and attr.type != attr_type:
raise ValueError(
f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
)
return attr

if attr_type is None:
attr_type = _infer_attribute_type(attr)

if attr_type == _enums.AttributeType.INT:
return _core.AttrInt64(name, attr) # type: ignore
if attr_type == _enums.AttributeType.FLOAT:
return _core.AttrFloat32(name, attr) # type: ignore
if attr_type == _enums.AttributeType.STRING:
return _core.AttrString(name, attr) # type: ignore
if attr_type == _enums.AttributeType.INTS:
return _core.AttrInt64s(name, attr) # type: ignore
if attr_type == _enums.AttributeType.FLOATS:
return _core.AttrFloat32s(name, attr) # type: ignore
if attr_type == _enums.AttributeType.STRINGS:
return _core.AttrStrings(name, attr) # type: ignore
if attr_type == _enums.AttributeType.TENSOR:
if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
return _core.AttrTensor(name, attr)
if isinstance(attr, onnx.TensorProto):
return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


def convert_attributes(
attrs: Mapping[str, SupportedAttrTypes],
) -> list[_core.Attr | _core.RefAttr]:
"""Convert a dictionary of attributes to a list of _core.Attr objects.
It infers the attribute type based on the type of the value. The supported
types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
:class:`_core.Tensor`, and :class:`_core.Attr`::
>>> from onnxscript import ir
>>> import onnx
>>> import numpy as np
>>> attrs = {
... "int": 1,
... "float": 1.0,
... "str": "hello",
... "ints": [1, 2, 3],
... "floats": [1.0, 2.0, 3.0],
... "strings": ["hello", "world"],
... "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
... "tensor_proto":
... onnx.TensorProto(
... dims=[3],
... data_type=onnx.TensorProto.FLOAT,
... float_data=[1.0, 2.0, 3.0],
... name="proto",
... ),
... }
>>> convert_attributes(attrs)
[AttrInt64('int', 1), AttrFloat32('float', 1.0), AttrString('str', 'hello'), AttrInt64s('ints', [1, 2, 3]), AttrFloat32s('floats', [1.0, 2.0, 3.0]), AttrStrings('strings', ['hello', 'world']), AttrTensor('tensor', Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name='')), AttrTensor('tensor_proto', TensorProtoTensor<FLOAT,[3]>(name='proto'))]
Args:
attrs: A dictionary of {<attribute name>: <python objects>} to convert.
def convert_attributes(attrs: Mapping[str, Any]) -> list[_core.Attr]:
attributes: list[_core.Attr] = []
Returns:
A list of _core.Attr objects.
"""
attributes: list[_core.Attr | _core.RefAttr] = []
for name, attr in attrs.items():
if isinstance(attr, int):
attributes.append(_core.AttrInt64(name, attr))
elif isinstance(attr, float):
attributes.append(_core.AttrFloat32(name, attr))
elif isinstance(attr, str):
attributes.append(_core.AttrString(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
attributes.append(_core.AttrInt64s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
attributes.append(_core.AttrFloat32s(name, attr))
elif isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
attributes.append(_core.AttrStrings(name, attr))
elif isinstance(attr, _core.Attr):
attributes.append(attr)
else:
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
attributes.append(convert_attribute(name, attr))
return attributes


Expand Down
35 changes: 20 additions & 15 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _printable_type_shape(self) -> str:
def _repr_base(self) -> str:
"""Base string for the repr method.
Example: Tensor<FLOAT:=1,5x42>
Example: Tensor<FLOAT,[5,42]>
"""
return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

Expand Down Expand Up @@ -239,7 +239,7 @@ def __dlpack_device__(self) -> tuple[int, int]:
return self.__array__().__dlpack_device__()

def __repr__(self) -> str:
return f"{self._repr_base()}({self._raw!r})"
return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

@property
def dtype(self) -> _enums.DataType:
Expand Down Expand Up @@ -2029,10 +2029,15 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"


class _SpecializedAttr(Attr):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r}, {self.value!r})"


# NOTE: The following classes are just supporting classes (partially applied) for convenience
# But I think they would be useful to have in the IR by having the type info
# explicitly in the class type.
class AttrFloat32(Attr):
class AttrFloat32(_SpecializedAttr):
def __init__(self, name: str, value: float, doc_string: str | None = None):
super().__init__(
name,
Expand All @@ -2042,7 +2047,7 @@ def __init__(self, name: str, value: float, doc_string: str | None = None):
)


class AttrInt64(Attr):
class AttrInt64(_SpecializedAttr):
def __init__(self, name: str, value: int, doc_string: str | None = None):
super().__init__(
name,
Expand All @@ -2052,7 +2057,7 @@ def __init__(self, name: str, value: int, doc_string: str | None = None):
)


class AttrString(Attr):
class AttrString(_SpecializedAttr):
def __init__(self, name: str, value: str, doc_string: str | None = None):
super().__init__(
name,
Expand All @@ -2062,7 +2067,7 @@ def __init__(self, name: str, value: str, doc_string: str | None = None):
)


class AttrTensor(Attr):
class AttrTensor(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2077,7 +2082,7 @@ def __init__(
)


class AttrGraph(Attr):
class AttrGraph(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2095,7 +2100,7 @@ def __str__(self) -> str:
return textwrap.indent("\n" + super().__str__(), " " * 4)


class AttrFloat32s(Attr):
class AttrFloat32s(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2110,7 +2115,7 @@ def __init__(
)


class AttrInt64s(Attr):
class AttrInt64s(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2125,7 +2130,7 @@ def __init__(
)


class AttrStrings(Attr):
class AttrStrings(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2140,7 +2145,7 @@ def __init__(
)


class AttrTensors(Attr):
class AttrTensors(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2155,7 +2160,7 @@ def __init__(
)


class AttrGraphs(Attr):
class AttrGraphs(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2171,7 +2176,7 @@ def __init__(


# NOTE: SparseTensor should be a sparse tensor proto
class AttrSparseTensor(Attr):
class AttrSparseTensor(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2186,7 +2191,7 @@ def __init__(
)


class AttrSparseTensors(Attr):
class AttrSparseTensors(_SpecializedAttr):
def __init__(
self,
name: str,
Expand All @@ -2201,7 +2206,7 @@ def __init__(
)


class AttrTypeProto(Attr):
class AttrTypeProto(_SpecializedAttr):
def __init__(
self,
name: str,
Expand Down
5 changes: 4 additions & 1 deletion onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def raw(self) -> onnx.TensorProto:
return self._proto

def __repr__(self) -> str:
return f"{self._repr_base()}({self.name!r})"
# It is a little hard to display the content when there can be types
# unsupported by numpy
# Preferably we should display some content when the tensor is small
return f"{self._repr_base()}(name={self.name!r})"

def __array__(self, dtype: Any = None) -> np.ndarray:
"""Return the tensor as a numpy array, compatible with np.array."""
Expand Down
Loading

0 comments on commit 5713872

Please sign in to comment.