Skip to content

Commit

Permalink
[IR] Implement methods to check dynamism on Shape (#1952)
Browse files Browse the repository at this point in the history
Define `is_static()` and `is_dynamic()` on Shape. Users can check if the
shape is static/dynamic, or if a specific axis is static/dynamic.

Fixes #1950
  • Loading branch information
justinchuby authored Nov 15, 2024
1 parent e6e3d52 commit 8c8417d
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 9 deletions.
67 changes: 58 additions & 9 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
Iterator,
OrderedDict,
Sequence,
SupportsInt,
Union,
)

import ml_dtypes
import numpy as np
from typing_extensions import TypeIs

import onnxscript
from onnxscript.ir import (
Expand Down Expand Up @@ -859,12 +861,37 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._value})"


def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
"""Return True if the value is int compatible."""
if isinstance(value, int):
return True
if hasattr(value, "__int__"):
# For performance reasons, we do not use isinstance(value, SupportsInt)
return True
return False


def _maybe_convert_to_symbolic_dim(
dim: int | SupportsInt | SymbolicDim | str | None,
) -> SymbolicDim | int:
"""Convert the value to a SymbolicDim if it is not an int."""
if dim is None or isinstance(dim, str):
return SymbolicDim(dim)
if _is_int_compatible(dim):
return int(dim)
if isinstance(dim, SymbolicDim):
return dim
raise TypeError(
f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
)


class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
__slots__ = ("_dims", "_frozen")

def __init__(
self,
dims: Iterable[int | SymbolicDim | str | None],
dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
/,
denotations: Iterable[str | None] | None = None,
frozen: bool = False,
Expand All @@ -885,8 +912,7 @@ def __init__(
is useful when the shape is initialized by a Tensor.
"""
self._dims: list[int | SymbolicDim] = [
SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim
for dim in dims
_maybe_convert_to_symbolic_dim(dim) for dim in dims
]
self._denotations: list[str | None] = (
list(denotations) if denotations is not None else [None] * len(self._dims)
Expand Down Expand Up @@ -946,12 +972,8 @@ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None
"""
if self._frozen:
raise TypeError("The shape is frozen and cannot be modified.")
if isinstance(value, str) or value is None:
value = SymbolicDim(value)
if not isinstance(value, (int, SymbolicDim)):
raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")

self._dims[index] = value
self._dims[index] = _maybe_convert_to_symbolic_dim(value)

def get_denotation(self, index: int) -> str | None:
"""Return the denotation of the dimension at the index.
Expand Down Expand Up @@ -986,7 +1008,7 @@ def __str__(self) -> str:
def __eq__(self, other: object) -> bool:
"""Return True if the shapes are equal.
Two shapes are eqaul if all their dimensions are equal.
Two shapes are equal if all their dimensions are equal.
"""
if isinstance(other, Shape):
return self._dims == other._dims
Expand All @@ -997,6 +1019,33 @@ def __eq__(self, other: object) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

@typing.overload
def is_static(self, dim: int) -> bool: # noqa: D418
"""Return True if the dimension is static."""

@typing.overload
def is_static(self) -> bool: # noqa: D418
"""Return True if all dimensions are static."""

def is_static(self, dim=None) -> bool:
"""Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
if dim is None:
return all(isinstance(dim, int) for dim in self._dims)
return isinstance(self[dim], int)

@typing.overload
def is_dynamic(self, dim: int) -> bool: # noqa: D418
"""Return True if the dimension is dynamic."""

@typing.overload
def is_dynamic(self) -> bool: # noqa: D418
"""Return True if any dimension is dynamic."""

def is_dynamic(self, dim=None) -> bool:
if dim is None:
return not self.is_static()
return not self.is_static(dim)


def _quoted(string: str) -> str:
"""Return a quoted string.
Expand Down
78 changes: 78 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,30 @@ def test_int_dimensions_are_python_ints(self):
shape = _core.Shape([42])
self.assertIsInstance(shape[0], int)

def test_str_dimensions_are_symbolic_dims(self):
shape = _core.Shape(["any string"])
self.assertIsInstance(shape[0], _core.SymbolicDim)

def test_none_dimensions_are_symbolic_dims(self):
shape = _core.Shape([None])
self.assertIsInstance(shape[0], _core.SymbolicDim)

def test_init_raises_when_dims_is_not_a_list(self):
with self.assertRaises(TypeError):
_core.Shape(42)

def test_init_converts_np_shape_to_tuple(self):
dims = np.array([42, 42])
shape = _core.Shape(dims)
self.assertEqual(shape.dims, tuple(dims))

def test_init_converts_np_int_to_python_int(self):
dims = [np.int32(42)]
shape = _core.Shape(dims)
self.assertIsInstance(shape[0], int)
self.assertNotIsInstance(shape[0], np.int32)
self.assertIsInstance(shape.dims[0], int)

@parameterized.parameterized.expand(
[
("empty", (), ()),
Expand Down Expand Up @@ -623,6 +647,10 @@ def test_setitem(self, _: str, value):
else:
self.assertEqual(dim, value)

def test_len(self):
shape = _core.Shape([42, "any string"])
self.assertEqual(len(shape), 2)

def test_get_denotation(self):
shape = _core.Shape([42], denotations=("DATA_CHANNEL",))
self.assertEqual(shape.get_denotation(0), "DATA_CHANNEL")
Expand All @@ -637,6 +665,56 @@ def test_set_denotation_is_still_possible_when_shape_is_frozen(self):
shape.set_denotation(0, "UPDATED")
self.assertEqual(shape.get_denotation(0), "UPDATED")

def test_is_static(self):
dim_from_numpy = np.array([42]).shape[0]
np_int = np.int32(42)
shape = _core.Shape([42, "any string", dim_from_numpy, np_int])
self.assertTrue(shape.is_static(0))
self.assertFalse(shape.is_static(1))
self.assertTrue(shape.is_static(2))
self.assertTrue(shape.is_static(3))
self.assertFalse(shape.is_static())

def test_is_static_raises_when_index_out_of_range(self):
shape = _core.Shape([42])
with self.assertRaises(IndexError):
shape.is_static(1)

def test_is_static_on_whole_shape(self):
shape = _core.Shape([42, "any string"])
self.assertFalse(shape.is_static())
shape = _core.Shape([42, 42])
self.assertTrue(shape.is_static())

def test_is_static_on_empty_shape(self):
shape = _core.Shape(())
self.assertTrue(shape.is_static())

def test_is_dynamic(self):
dim_from_numpy = np.array([42]).shape[0]
np_int = np.int32(42)
shape = _core.Shape([42, "any string", dim_from_numpy, np_int])
self.assertFalse(shape.is_dynamic(0))
self.assertTrue(shape.is_dynamic(1))
self.assertFalse(shape.is_dynamic(2))
self.assertFalse(shape.is_dynamic(3))
self.assertTrue(shape.is_dynamic())

def test_is_dynamic_raises_when_index_out_of_range(self):
shape = _core.Shape([42])
with self.assertRaises(IndexError):
shape.is_dynamic(1)

def test_is_dynamic_on_whole_shape(self):
shape = _core.Shape([42, "any string"])
self.assertTrue(shape.is_dynamic())
shape = _core.Shape([42, 42])
self.assertFalse(shape.is_dynamic())

def test_is_dynamic_on_empty_shape(self):
shape = _core.Shape(())
self.assertFalse(shape.is_dynamic())


class ValueTest(unittest.TestCase):
def test_initialize(self):
Expand Down

0 comments on commit 8c8417d

Please sign in to comment.