Skip to content

Commit

Permalink
[Feature] allow tensorclass to be customized
Browse files Browse the repository at this point in the history
ghstack-source-id: 0b65b0a2dfb0cd7b5113e245c9444d3a0b55d085
Pull Request resolved: #1080
  • Loading branch information
vmoens committed Nov 7, 2024
1 parent 5125217 commit 4f794d6
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 9 deletions.
79 changes: 73 additions & 6 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,9 @@ def __init__(
@dataclass_transform()
def __call__(self, cls: T) -> T:
clz = _tensorclass(cls, frozen=self.frozen)
clz.autocast = self.autocast
clz.nocast = self.nocast
clz._autocast = self.autocast
clz._nocast = self.nocast
clz._frozen = self.frozen
return clz


Expand Down Expand Up @@ -1647,7 +1648,7 @@ def set_tensor(
def _is_castable(datatype):
return issubclass(datatype, (int, float, np.ndarray))

if cls.autocast:
if cls._autocast:
type_hints = cls._type_hints
if type_hints is not None:
target_cls = type_hints.get(key, _AnyType)
Expand Down Expand Up @@ -1693,7 +1694,7 @@ def _is_castable(datatype):
issubclass(value_type, torch.Tensor)
or _is_tensor_collection(value_type)
or (
not cls.nocast
not cls._nocast
and issubclass(value_type, (int, float, bool, np.ndarray))
)
):
Expand Down Expand Up @@ -3338,15 +3339,46 @@ def _update_shared_nontensor(nontensor, val):


class _TensorClassMeta(abc.ABCMeta):
def __new__(mcs, name, bases, namespace, **kwargs):
def __new__(
mcs, name, bases, namespace, autocast=None, nocast=None, frozen=None, **kwargs
):
# Create the class using the ABCMeta's __new__ method
cls = super().__new__(mcs, name, bases, namespace, **kwargs)

# Apply the dataclass decorator to the class
cls = _tensorclass(cls, frozen=False)
if frozen is None and hasattr(cls, "_frozen"):
frozen = cls._frozen
if nocast is None and hasattr(cls, "_nocast"):
nocast = cls._nocast
if autocast is None and hasattr(cls, "_autocast"):
autocast = cls._autocast

if name == "TensorClass" and "tensordict.tensorclass" in namespace.get(
"__module__", ""
):
pass
else:
cls = tensorclass(
frozen=bool(frozen), nocast=bool(nocast), autocast=bool(autocast)
)(cls)

return cls

def __getitem__(cls, item):
if not isinstance(item, tuple):
item = (item,)
name = "_".join(item)
cls_name = f"TensorClass_{name}"
bases = (cls,)
class_dict = {}
return cls.__class__.__new__(
cls.__class__,
cls_name,
bases,
class_dict,
**{_item: True for _item in item},
)


class TensorClass(metaclass=_TensorClassMeta):
"""TensorClass is the inheritance-based version of the @tensorclass decorator.
Expand All @@ -3372,6 +3404,41 @@ class TensorClass(metaclass=_TensorClassMeta):
device=None,
is_shared=False)
You can pass keyword arguments in two ways: using brackets or keyword arguments.
Examples:
>>> class Foo(TensorClass["autocast"]):
... integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True): # equivalent
... integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
... integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]): # multiple keywords can be used
... integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True): # equivalent
... integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
... integer: int
>>> Foo(integer=1).integer
tensor(1)
.. warning:: TensorClass itself is not decorated as a tensorclass, but subclasses will be.
This is because we cannot anticipate if the frozen argument will be set, and if it is, it may
conflict with the parent class (a subclass cannot be frozen if the parent class isn't).
"""

_autocast: bool = False
_nocast: bool = False
_frozen: bool = False
...
3 changes: 3 additions & 0 deletions tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ T = TypeVar("T", bound="TensorDictBase")

@dataclasses.dataclass
class TensorClass:
_autocast: bool = False
_nocast: bool = False
_frozen: bool = False
def __init__(
self,
*args,
Expand Down
86 changes: 83 additions & 3 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import tensordict.utils
import torch
from tensordict import TensorClass

try:
import torchsnapshot
Expand Down Expand Up @@ -2217,19 +2218,19 @@ def test_autocast_attr(self):
class T:
X: torch.Tensor

assert not T.autocast
assert not T._autocast

@tensorclass
class T:
X: torch.Tensor

assert not T.autocast
assert not T._autocast

@tensorclass(autocast=True)
class T:
X: torch.Tensor

assert T.autocast
assert T._autocast

def test_autocast_simple(self):
obj = AutoCastTensor(
Expand Down Expand Up @@ -2500,6 +2501,85 @@ class X:
assert (x.div(2) == (x / 2)).all()


class TestSubClassing:
def test_subclassing(self):
class SubClass(TensorClass):
a: int

assert is_tensorclass(SubClass)
assert not SubClass._autocast
assert not SubClass._nocast
assert issubclass(SubClass, TensorClass)

def test_subclassing_autocast(self):
class SubClass(TensorClass, autocast=True):
a: int

assert is_tensorclass(SubClass)
assert SubClass._autocast
assert not SubClass._nocast
assert issubclass(SubClass, TensorClass)
assert isinstance(SubClass(torch.ones(())).a, int)

class SubClass(TensorClass["autocast"]):
a: int

assert not TensorClass._autocast
assert is_tensorclass(SubClass)
assert SubClass._autocast
assert not SubClass._nocast
assert issubclass(SubClass, TensorClass)
assert isinstance(SubClass(torch.ones(())).a, int)

def test_subclassing_nocast(self):
class SubClass(TensorClass, nocast=True):
a: int

assert is_tensorclass(SubClass)
assert not SubClass._autocast
assert SubClass._nocast
assert issubclass(SubClass, TensorClass)
assert isinstance(SubClass(1).a, int)

class SubClass(TensorClass["nocast"]):
a: int

assert not TensorClass._nocast
assert is_tensorclass(SubClass)
assert not SubClass._autocast
assert SubClass._nocast
assert issubclass(SubClass, TensorClass)
assert isinstance(SubClass(1).a, int)

def test_subclassing_mult(self):
class SubClass(TensorClass, nocast=True, frozen=True):
a: int

assert is_tensorclass(SubClass)
assert not SubClass._autocast
assert SubClass._nocast
assert SubClass._frozen
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
s.a = 2

class SubClass(TensorClass["nocast", "frozen"]):
a: int

assert not TensorClass._nocast
assert not TensorClass._frozen
assert is_tensorclass(SubClass)
assert SubClass._nocast
assert SubClass._frozen
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
s.a = 2


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 comments on commit 4f794d6

Please sign in to comment.