Skip to content

Commit

Permalink
[IR] Make tensor types hashable (#1576)
Browse files Browse the repository at this point in the history
Make tensor types hashable so that it is easy to check types in a set of
accepted types during schema matching.

Test hashable properties on more classes in the IR.
  • Loading branch information
justinchuby authored May 29, 2024
1 parent b312348 commit 34e410a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
11 changes: 9 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Any,
Collection,
Generic,
Hashable,
Iterable,
Iterator,
OrderedDict,
Expand Down Expand Up @@ -1267,7 +1268,7 @@ def display(self, *, page: bool | None = None) -> None:
super().display(page=page)


class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable):
class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
"""Tensor types that are non recursive types."""

__slots__ = ("_dtype", "denotation")
Expand All @@ -1289,6 +1290,9 @@ def elem_type(self) -> _enums.DataType:
"""Return the element type of the tensor type"""
return self.dtype

def __hash__(self) -> int:
return hash(repr(self))

def __eq__(self, other: object) -> bool:
if self.__class__ is not other.__class__:
return False
Expand All @@ -1311,7 +1315,7 @@ class SparseTensorType(_TensorTypeBase):
"""A type that represents a sparse tensor."""


class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable):
class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
"""Base for recursive types like Optional and Sequence."""

__slots__ = ("_elem_type", "denotation")
Expand All @@ -1334,6 +1338,9 @@ def dtype(self, value: _enums.DataType) -> None:
def elem_type(self) -> _protocols.TypeProtocol:
return self._elem_type

def __hash__(self) -> int:
return hash(repr(self))

def __eq__(self, other: object) -> bool:
if not isinstance(other, _RecursiveTypeBase):
return False
Expand Down
65 changes: 65 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# --------------------------------------------------------------------------
from __future__ import annotations

import copy
import pathlib
import tempfile
import unittest
Expand Down Expand Up @@ -575,6 +576,11 @@ class ValueTest(unittest.TestCase):
def test_initialize(self):
_ = _core.Value()

def test_it_is_hashable(self):
value = _core.Value()
self.assertIsInstance(hash(value), int)
self.assertIn(value, {value})

def test_meta(self):
value = _core.Value()
value.meta["test"] = 1
Expand All @@ -591,6 +597,10 @@ def setUp(self) -> None:
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.node), int)
self.assertIn(self.node, {self.node})

def test_init_with_values(self):
self.assertEqual(self.node.domain, "test")
self.assertEqual(self.node.op_type, "TestOp")
Expand Down Expand Up @@ -678,6 +688,10 @@ def test_initialize(self):
self.assertEqual(self.graph.initializers, {})
self.assertIsNone(self.graph.doc_string)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.graph), int)
self.assertIn(self.graph, {self.graph})

def test_it_is_iterable_of_nodes(self):
self.assertEqual(list(self.graph), [self.node])

Expand Down Expand Up @@ -767,5 +781,56 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
# TODO(justinchuby): Test graph mutation methods


class TypeTest(unittest.TestCase):
@parameterized.parameterized.expand(
[
("tensor", _core.TensorType(ir.DataType.FLOAT)),
("sequence", _core.SequenceType(_core.TensorType(ir.DataType.BOOL))),
("optional", _core.OptionalType(_core.TensorType(ir.DataType.FLOAT16))),
(
"sequence_optional",
_core.SequenceType(_core.OptionalType(_core.TensorType(ir.DataType.INT8))),
),
(
"optional_sequence",
_core.OptionalType(_core.SequenceType(_core.TensorType(ir.DataType.INT16))),
),
]
)
def test_type_is_hashable(self, _: str, type_: ir.TypeProtocol):
self.assertIsInstance(hash(type_), int)
self.assertIn(type_, {type_}) # type: ignore
# Assert that a different type object can still be matched
self.assertIn(copy.deepcopy(type_), {type_}) # type: ignore

def test_type_is_comparable(self):
self.assertEqual(
_core.TensorType(ir.DataType.FLOAT), _core.TensorType(ir.DataType.FLOAT)
)
self.assertNotEqual(
_core.TensorType(ir.DataType.FLOAT), _core.TensorType(ir.DataType.FLOAT16)
)

@parameterized.parameterized.expand(
[
("tensor", _core.TensorType(ir.DataType.FLOAT)),
("sequence", _core.SequenceType(_core.TensorType(ir.DataType.BOOL))),
("optional", _core.OptionalType(_core.TensorType(ir.DataType.FLOAT16))),
(
"sequence_optional",
_core.SequenceType(_core.OptionalType(_core.TensorType(ir.DataType.INT8))),
),
(
"optional_sequence",
_core.OptionalType(_core.SequenceType(_core.TensorType(ir.DataType.INT16))),
),
]
)
def test_composite_type_is_comparable(self, _: str, type_: ir.TypeProtocol):
self.assertEqual(type_, type_)
# Equal even if deep-copied
self.assertEqual(type_, copy.deepcopy(type_))


if __name__ == "__main__":
unittest.main()

0 comments on commit 34e410a

Please sign in to comment.