From 4f794d6710153b862465a83d4fd6c4a967531386 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 7 Nov 2024 13:04:57 +0000 Subject: [PATCH] [Feature] allow tensorclass to be customized ghstack-source-id: 0b65b0a2dfb0cd7b5113e245c9444d3a0b55d085 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1080 --- tensordict/tensorclass.py | 79 +++++++++++++++++++++++++++++++--- tensordict/tensorclass.pyi | 3 ++ test/test_tensorclass.py | 86 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 159 insertions(+), 9 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index fff0bb574..2b0ff493f 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -373,8 +373,9 @@ def __init__( @dataclass_transform() def __call__(self, cls: T) -> T: clz = _tensorclass(cls, frozen=self.frozen) - clz.autocast = self.autocast - clz.nocast = self.nocast + clz._autocast = self.autocast + clz._nocast = self.nocast + clz._frozen = self.frozen return clz @@ -1647,7 +1648,7 @@ def set_tensor( def _is_castable(datatype): return issubclass(datatype, (int, float, np.ndarray)) - if cls.autocast: + if cls._autocast: type_hints = cls._type_hints if type_hints is not None: target_cls = type_hints.get(key, _AnyType) @@ -1693,7 +1694,7 @@ def _is_castable(datatype): issubclass(value_type, torch.Tensor) or _is_tensor_collection(value_type) or ( - not cls.nocast + not cls._nocast and issubclass(value_type, (int, float, bool, np.ndarray)) ) ): @@ -3338,15 +3339,46 @@ def _update_shared_nontensor(nontensor, val): class _TensorClassMeta(abc.ABCMeta): - def __new__(mcs, name, bases, namespace, **kwargs): + def __new__( + mcs, name, bases, namespace, autocast=None, nocast=None, frozen=None, **kwargs + ): # Create the class using the ABCMeta's __new__ method cls = super().__new__(mcs, name, bases, namespace, **kwargs) # Apply the dataclass decorator to the class - cls = _tensorclass(cls, frozen=False) + if frozen is None and hasattr(cls, "_frozen"): + frozen = cls._frozen + if nocast is None and hasattr(cls, "_nocast"): + nocast = cls._nocast + if autocast is None and hasattr(cls, "_autocast"): + autocast = cls._autocast + + if name == "TensorClass" and "tensordict.tensorclass" in namespace.get( + "__module__", "" + ): + pass + else: + cls = tensorclass( + frozen=bool(frozen), nocast=bool(nocast), autocast=bool(autocast) + )(cls) return cls + def __getitem__(cls, item): + if not isinstance(item, tuple): + item = (item,) + name = "_".join(item) + cls_name = f"TensorClass_{name}" + bases = (cls,) + class_dict = {} + return cls.__class__.__new__( + cls.__class__, + cls_name, + bases, + class_dict, + **{_item: True for _item in item}, + ) + class TensorClass(metaclass=_TensorClassMeta): """TensorClass is the inheritance-based version of the @tensorclass decorator. @@ -3372,6 +3404,41 @@ class TensorClass(metaclass=_TensorClassMeta): device=None, is_shared=False) + You can pass keyword arguments in two ways: using brackets or keyword arguments. + + Examples: + >>> class Foo(TensorClass["autocast"]): + ... integer: int + >>> Foo(integer=torch.ones(())).integer + 1 + >>> class Foo(TensorClass, autocast=True): # equivalent + ... integer: int + >>> Foo(integer=torch.ones(())).integer + 1 + >>> class Foo(TensorClass["nocast"]): + ... integer: int + >>> Foo(integer=1).integer + 1 + >>> class Foo(TensorClass["nocast", "frozen"]): # multiple keywords can be used + ... integer: int + >>> Foo(integer=1).integer + 1 + >>> class Foo(TensorClass, nocast=True): # equivalent + ... integer: int + >>> Foo(integer=1).integer + 1 + >>> class Foo(TensorClass): + ... integer: int + >>> Foo(integer=1).integer + tensor(1) + + .. warning:: TensorClass itself is not decorated as a tensorclass, but subclasses will be. + This is because we cannot anticipate if the frozen argument will be set, and if it is, it may + conflict with the parent class (a subclass cannot be frozen if the parent class isn't). + """ + _autocast: bool = False + _nocast: bool = False + _frozen: bool = False ... diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 63cdaa181..75678b4b6 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -67,6 +67,9 @@ T = TypeVar("T", bound="TensorDictBase") @dataclasses.dataclass class TensorClass: + _autocast: bool = False + _nocast: bool = False + _frozen: bool = False def __init__( self, *args, diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 32d5dab91..adadd97ec 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -22,6 +22,7 @@ import pytest import tensordict.utils import torch +from tensordict import TensorClass try: import torchsnapshot @@ -2217,19 +2218,19 @@ def test_autocast_attr(self): class T: X: torch.Tensor - assert not T.autocast + assert not T._autocast @tensorclass class T: X: torch.Tensor - assert not T.autocast + assert not T._autocast @tensorclass(autocast=True) class T: X: torch.Tensor - assert T.autocast + assert T._autocast def test_autocast_simple(self): obj = AutoCastTensor( @@ -2500,6 +2501,85 @@ class X: assert (x.div(2) == (x / 2)).all() +class TestSubClassing: + def test_subclassing(self): + class SubClass(TensorClass): + a: int + + assert is_tensorclass(SubClass) + assert not SubClass._autocast + assert not SubClass._nocast + assert issubclass(SubClass, TensorClass) + + def test_subclassing_autocast(self): + class SubClass(TensorClass, autocast=True): + a: int + + assert is_tensorclass(SubClass) + assert SubClass._autocast + assert not SubClass._nocast + assert issubclass(SubClass, TensorClass) + assert isinstance(SubClass(torch.ones(())).a, int) + + class SubClass(TensorClass["autocast"]): + a: int + + assert not TensorClass._autocast + assert is_tensorclass(SubClass) + assert SubClass._autocast + assert not SubClass._nocast + assert issubclass(SubClass, TensorClass) + assert isinstance(SubClass(torch.ones(())).a, int) + + def test_subclassing_nocast(self): + class SubClass(TensorClass, nocast=True): + a: int + + assert is_tensorclass(SubClass) + assert not SubClass._autocast + assert SubClass._nocast + assert issubclass(SubClass, TensorClass) + assert isinstance(SubClass(1).a, int) + + class SubClass(TensorClass["nocast"]): + a: int + + assert not TensorClass._nocast + assert is_tensorclass(SubClass) + assert not SubClass._autocast + assert SubClass._nocast + assert issubclass(SubClass, TensorClass) + assert isinstance(SubClass(1).a, int) + + def test_subclassing_mult(self): + class SubClass(TensorClass, nocast=True, frozen=True): + a: int + + assert is_tensorclass(SubClass) + assert not SubClass._autocast + assert SubClass._nocast + assert SubClass._frozen + assert issubclass(SubClass, TensorClass) + s = SubClass(1) + assert isinstance(s.a, int) + with pytest.raises(RuntimeError): + s.a = 2 + + class SubClass(TensorClass["nocast", "frozen"]): + a: int + + assert not TensorClass._nocast + assert not TensorClass._frozen + assert is_tensorclass(SubClass) + assert SubClass._nocast + assert SubClass._frozen + assert issubclass(SubClass, TensorClass) + s = SubClass(1) + assert isinstance(s.a, int) + with pytest.raises(RuntimeError): + s.a = 2 + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)