Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 7, 2024
1 parent 850f35b commit 7785842
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
60 changes: 45 additions & 15 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,28 +351,41 @@ def is_non_tensor(obj):


class _tensorclass_dec:
def __new__(cls, autocast: bool = False, frozen: bool = False):
def __new__(
cls, autocast: bool = False, frozen: bool = False, nocast: bool = False
):
if not isinstance(autocast, bool):
clz = autocast
self = super().__new__(cls)
self.__init__(autocast=False, frozen=False)
self.__init__(autocast=False, frozen=False, nocast=False)
return self.__call__(clz)
return super().__new__(cls)

def __init__(self, autocast: bool = False, frozen: bool = False):
def __init__(
self, autocast: bool = False, frozen: bool = False, nocast: bool = False
):
if autocast and nocast:
raise ValueError("autocast is exclusive with nocast.")
self.autocast = autocast
self.frozen = frozen
self.nocast = nocast

@dataclass_transform()
def __call__(self, cls: T) -> T:
clz = _tensorclass(cls, frozen=self.frozen)
clz.autocast = self.autocast
clz.nocast = self.nocast
return clz


@dataclass_transform()
def tensorclass(
cls: T = None, /, *, autocast: bool = False, frozen: bool = False
cls: T = None,
/,
*,
autocast: bool = False,
frozen: bool = False,
nocast: bool = False,
) -> T | None:
"""A decorator to create :obj:`tensorclass` classes.
Expand All @@ -381,30 +394,44 @@ def tensorclass(
indexing, item assignment, reshaping, casting to device or storage and many
others.
Args:
Keyword Args:
autocast (bool, optional): if ``True``, the types indicated will be enforced when an argument is set.
Defaults to ``False``.
Thie argument is exclusive with ``autocast`` (both cannot be true at the same time). Defaults to ``False``.
frozen (bool, optional): if ``True``, the content of the tensorclass cannot be modified. This argument is
provided to dataclass-compatibility, a similar behavior can be obtained through the `lock` argument in
the class constructor. Defaults to ``False``.
nocast (bool, optional): if ``True``, Tensor-compatible types such as ``int``, ``np.ndarray`` and the like
will not be cast to a tensor type. Thie argument is exclusive with ``autocast`` (both cannot be true
at the same time). Defaults to ``False``.
tensorclass can be used with or without arguments:
Examples:
>>> @tensorclass
... class X:
... y: torch.Tensor
>>> X(1).y
1
... y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
... y: torch.Tensor
... y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
... y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
... y: Any
>>> X(1).y
1
>>> @tensorclass(autocast=True)
>>> @tensorclass(nocast=False)
... class X:
... y: torch.Tensor
... y: Any
>>> X(1).y
torch.tensor(1)
tensor(1)
Examples:
>>> from tensordict import tensorclass
Expand Down Expand Up @@ -456,7 +483,7 @@ def tensorclass(
"""

def wrap(cls):
return _tensorclass_dec(autocast, frozen)(cls)
return _tensorclass_dec(autocast, frozen, nocast)(cls)

# See if we're being called as @tensorclass or @tensorclass().
if cls is None:
Expand Down Expand Up @@ -1665,7 +1692,10 @@ def _is_castable(datatype):
elif (
issubclass(value_type, torch.Tensor)
or _is_tensor_collection(value_type)
or issubclass(value_type, (int, float, bool, np.ndarray))
or (
not cls.nocast
and issubclass(value_type, (int, float, bool, np.ndarray))
)
):
return set_tensor()
else:
Expand Down
48 changes: 48 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2155,6 +2155,54 @@ class AutoCastTensor:
anything: Any


class TestNoCasting:
def test_nocast_int(self):
@tensorclass(nocast=False)
class X:
a: int # type is irrelevant

assert isinstance(X(1).a, torch.Tensor)

@tensorclass(nocast=True)
class X:
a: int # type is irrelevant

assert isinstance(X(1).a, int)

def test_nocast_np(self):
@tensorclass(nocast=False)
class X:
a: int # type is irrelevant

assert isinstance(X(np.array([1])).a, torch.Tensor)

@tensorclass(nocast=True)
class X:
a: int # type is irrelevant

assert isinstance(X(np.array([1])).a, np.ndarray)

def test_nocast_bool(self):
@tensorclass(nocast=False)
class X:
a: int # type is irrelevant

assert isinstance(X(True).a, torch.Tensor)

@tensorclass(nocast=True)
class X:
a: int # type is irrelevant

assert isinstance(X(False).a, bool)

def test_exclusivity(self):
with pytest.raises(ValueError, match="exclusive"):

@tensorclass(nocast=True, autocast=True)
class X:
a: int # type is irrelevant


class TestAutoCasting:
@tensorclass(autocast=True)
class ClsAutoCast:
Expand Down

0 comments on commit 7785842

Please sign in to comment.