From 778584211dff55050db9b884c0675b22a45bb5fc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 7 Nov 2024 10:41:02 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 60 +++++++++++++++++++++++++++++---------- test/test_tensorclass.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 213406bd1..fff0bb574 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -351,28 +351,41 @@ def is_non_tensor(obj): class _tensorclass_dec: - def __new__(cls, autocast: bool = False, frozen: bool = False): + def __new__( + cls, autocast: bool = False, frozen: bool = False, nocast: bool = False + ): if not isinstance(autocast, bool): clz = autocast self = super().__new__(cls) - self.__init__(autocast=False, frozen=False) + self.__init__(autocast=False, frozen=False, nocast=False) return self.__call__(clz) return super().__new__(cls) - def __init__(self, autocast: bool = False, frozen: bool = False): + def __init__( + self, autocast: bool = False, frozen: bool = False, nocast: bool = False + ): + if autocast and nocast: + raise ValueError("autocast is exclusive with nocast.") self.autocast = autocast self.frozen = frozen + self.nocast = nocast @dataclass_transform() def __call__(self, cls: T) -> T: clz = _tensorclass(cls, frozen=self.frozen) clz.autocast = self.autocast + clz.nocast = self.nocast return clz @dataclass_transform() def tensorclass( - cls: T = None, /, *, autocast: bool = False, frozen: bool = False + cls: T = None, + /, + *, + autocast: bool = False, + frozen: bool = False, + nocast: bool = False, ) -> T | None: """A decorator to create :obj:`tensorclass` classes. @@ -381,30 +394,44 @@ def tensorclass( indexing, item assignment, reshaping, casting to device or storage and many others. - Args: + Keyword Args: autocast (bool, optional): if ``True``, the types indicated will be enforced when an argument is set. - Defaults to ``False``. + Thie argument is exclusive with ``autocast`` (both cannot be true at the same time). Defaults to ``False``. frozen (bool, optional): if ``True``, the content of the tensorclass cannot be modified. This argument is provided to dataclass-compatibility, a similar behavior can be obtained through the `lock` argument in the class constructor. Defaults to ``False``. + nocast (bool, optional): if ``True``, Tensor-compatible types such as ``int``, ``np.ndarray`` and the like + will not be cast to a tensor type. Thie argument is exclusive with ``autocast`` (both cannot be true + at the same time). Defaults to ``False``. tensorclass can be used with or without arguments: + Examples: >>> @tensorclass ... class X: - ... y: torch.Tensor - >>> X(1).y - 1 + ... y: int + >>> X(torch.ones(())).y + tensor(1.) >>> @tensorclass(autocast=False) ... class X: - ... y: torch.Tensor + ... y: int + >>> X(torch.ones(())).y + tensor(1.) + >>> @tensorclass(autocast=True) + ... class X: + ... y: int + >>> X(torch.ones(())).y + 1 + >>> @tensorclass(nocast=True) + ... class X: + ... y: Any >>> X(1).y 1 - >>> @tensorclass(autocast=True) + >>> @tensorclass(nocast=False) ... class X: - ... y: torch.Tensor + ... y: Any >>> X(1).y - torch.tensor(1) + tensor(1) Examples: >>> from tensordict import tensorclass @@ -456,7 +483,7 @@ def tensorclass( """ def wrap(cls): - return _tensorclass_dec(autocast, frozen)(cls) + return _tensorclass_dec(autocast, frozen, nocast)(cls) # See if we're being called as @tensorclass or @tensorclass(). if cls is None: @@ -1665,7 +1692,10 @@ def _is_castable(datatype): elif ( issubclass(value_type, torch.Tensor) or _is_tensor_collection(value_type) - or issubclass(value_type, (int, float, bool, np.ndarray)) + or ( + not cls.nocast + and issubclass(value_type, (int, float, bool, np.ndarray)) + ) ): return set_tensor() else: diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 469454ac0..32d5dab91 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -2155,6 +2155,54 @@ class AutoCastTensor: anything: Any +class TestNoCasting: + def test_nocast_int(self): + @tensorclass(nocast=False) + class X: + a: int # type is irrelevant + + assert isinstance(X(1).a, torch.Tensor) + + @tensorclass(nocast=True) + class X: + a: int # type is irrelevant + + assert isinstance(X(1).a, int) + + def test_nocast_np(self): + @tensorclass(nocast=False) + class X: + a: int # type is irrelevant + + assert isinstance(X(np.array([1])).a, torch.Tensor) + + @tensorclass(nocast=True) + class X: + a: int # type is irrelevant + + assert isinstance(X(np.array([1])).a, np.ndarray) + + def test_nocast_bool(self): + @tensorclass(nocast=False) + class X: + a: int # type is irrelevant + + assert isinstance(X(True).a, torch.Tensor) + + @tensorclass(nocast=True) + class X: + a: int # type is irrelevant + + assert isinstance(X(False).a, bool) + + def test_exclusivity(self): + with pytest.raises(ValueError, match="exclusive"): + + @tensorclass(nocast=True, autocast=True) + class X: + a: int # type is irrelevant + + class TestAutoCasting: @tensorclass(autocast=True) class ClsAutoCast: