Skip to content

Commit

Permalink
[BugFix] Fix unitary ops for tensorclass
Browse files Browse the repository at this point in the history
ghstack-source-id: 2d117645769890b72f5856f68acbe1b48015cfbb
Pull Request resolved: #1164
  • Loading branch information
vmoens committed Jan 7, 2025
1 parent c0c6c14 commit 148c823
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,19 @@ def __subclasscheck__(self, subclass):
]

# Methods to be executed from tensordict, any ref to self means 'self._tensordict'
_FALLBACK_METHOD_FROM_TD_FORCE = [
"__ge__",
"__gt__",
"__le__",
"__lt__",
"__ror__",
]
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__and__",
"__bool__",
"__eq__",
"__ge__",
"__gt__",
"__iadd__",
"__imul__",
"__invert__",
Expand All @@ -185,7 +190,6 @@ def __subclasscheck__(self, subclass):
"__radd__",
"__rand__",
"__rmul__",
"__ror__",
"__rpow__",
"__rsub__",
"__rtruediv__",
Expand Down Expand Up @@ -240,6 +244,7 @@ def __subclasscheck__(self, subclass):
"auto_batch_size_",
"auto_device_",
"bitwise_and",
"bool",
"ceil",
"ceil_",
"chunk",
Expand Down Expand Up @@ -814,6 +819,8 @@ def __torch_function__(
for method_name in _FALLBACK_METHOD_FROM_TD:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name))
for method_name in _FALLBACK_METHOD_FROM_TD_FORCE:
setattr(cls, method_name, _wrap_td_method(method_name))
for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
if not hasattr(cls, method_name):
setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True))
Expand Down
41 changes: 41 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,47 @@ class X:
assert (x.mul(2) == (x * 2)).all()
assert (x.div(2) == (x / 2)).all()

def test_logic_and_right_ops(self):
@tensorclass
class MyClass:
x: str

c = MyClass(torch.randn(10))
_ = c < 0
_ = c > 0
_ = c <= 0
_ = c >= 0
_ = c != 0

_ = c.bool() ^ True
_ = True ^ c.bool()

_ = c.bool() | False
_ = False | c.bool()

_ = c.bool() & False
_ = False & c.bool()

_ = abs(c)

_ = c + 1
_ = 1 + c
c += 1

_ = c * 1
_ = 1 * c

_ = c - 1
_ = 1 - c
c -= 1

_ = c / 1
_ = 1 / c

_ = c**1
# not implemented
# 1 ** c


class TestSubClassing:
def test_subclassing(self):
Expand Down

0 comments on commit 148c823

Please sign in to comment.