From 8344930123e6cbfbfff46e6043f0e41a24726a09 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Nov 2024 17:47:53 +0100 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 1 + tensordict/_lazy.py | 26 ++++- tensordict/_td.py | 150 +++++++++++++++++++++++--- tensordict/base.py | 120 ++++++++++++++++++++- tensordict/nn/params.py | 8 +- tensordict/persistent.py | 28 ++++- tensordict/tensorclass.py | 104 +++++++++++++++++- tensordict/tensorclass.pyi | 7 ++ tensordict/utils.py | 2 +- test/_utils_internal.py | 58 ++++++---- test/test_tensorclass.py | 53 +++++++++ test/test_tensordict.py | 50 +++++++++ 13 files changed, 564 insertions(+), 44 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17dceff06..ea55aef40 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -282,6 +282,7 @@ Here is an example: TensorClass NonTensorData NonTensorStack + from_dataclass Auto-casting ------------ diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 364a11f5a..7fc9d349d 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -43,6 +43,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.persistent import PersistentTensorDict from tensordict.tensorclass import ( + from_dataclass, NonTensorData, NonTensorStack, tensorclass, diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index eb4248671..fc87258bf 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -329,15 +329,37 @@ def _reduce_get_metadata(self): @classmethod def from_dict( cls, - input_dict, + input_dict: List[Dict[NestedKey, Any]], + *other, + auto_batch_size: bool = False, batch_size=None, device=None, batch_dims=None, stack_dim_name=None, stack_dim=0, ): + if batch_size is not None: + batch_size = list(batch_size) + if stack_dim is None: + stack_dim = 0 + n = batch_size.pop(stack_dim) + if n != len(input_dict): + raise ValueError( + "The number of dicts and the corresponding batch-size must match, " + f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." + ) + batch_size = torch.Size(batch_size) return LazyStackedTensorDict( - *(input_dict[str(i)] for i in range(len(input_dict))), + *( + TensorDict.from_dict( + input_dict[str(i)], + *other, + auto_batch_size=auto_batch_size, + device=device, + batch_dims=batch_dims, + ) + for i in range(len(input_dict)) + ), stack_dim=stack_dim, stack_dim_name=stack_dim_name, ) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7895fae4e..3029aeab0 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1957,8 +1957,46 @@ def _unsqueeze(tensor): @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -1967,12 +2005,12 @@ def from_dict( batch_size_set = torch.Size(()) if batch_size is None else batch_size input_dict = dict(input_dict) for key, value in list(input_dict.items()): - if isinstance(value, (dict,)): - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None - ) + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=False, + ) # regular __init__ breaks because a tensor may have the same batch-size as the tensordict out = cls( input_dict, @@ -1981,7 +2019,17 @@ def from_dict( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -1998,8 +2046,46 @@ def _from_dict_validated( ) def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None, names=None + self, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + if device is not None: + raise TypeError( + "conflicting device values. Please use the keyword argument only." + ) + if batch_dims is not None: + raise TypeError( + "conflicting batch_dims values. Please use the keyword argument only." + ) + if names is not None: + raise TypeError( + "conflicting names values. Please use the keyword argument only." + ) + warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -2014,14 +2100,24 @@ def from_dict_instance( cur_value = self.get(key, None) if cur_value is not None: input_dict[key] = cur_value.from_dict_instance( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=auto_batch_size, ) continue # we don't know if another tensor of smaller size is coming # so we can't be sure that the batch-size will still be valid later input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=auto_batch_size, + ) + else: + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=auto_batch_size, ) + out = TensorDict.from_dict( input_dict, batch_size=batch_size_set, @@ -2029,7 +2125,17 @@ def from_dict_instance( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -3857,7 +3963,14 @@ def expand(self, *args: int, inplace: bool = False) -> T: @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): raise NotImplementedError(f"from_dict not implemented for {cls.__name__}.") @@ -4273,6 +4386,12 @@ def _items( (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict._source.keys() ) + from tensordict.persistent import PersistentTensorDict + + if isinstance(tensordict, PersistentTensorDict): + return ( + (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict.keys() + ) raise NotImplementedError(type(tensordict)) def _keys(self) -> _TensorDictKeysView: @@ -4697,7 +4816,9 @@ def from_modules( ) -def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None): +def from_dict( + input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None +): """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. If ``batch_size`` is not specified, returns the maximum batch size possible. @@ -4762,6 +4883,7 @@ def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=N """ return TensorDict.from_dict( input_dict, + *others, batch_size=batch_size, device=device, batch_dims=batch_dims, diff --git a/tensordict/base.py b/tensordict/base.py index 39729eba4..bea8af10c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -12,6 +12,7 @@ import enum import gc import importlib +import importlib.util import os.path import queue import uuid @@ -112,6 +113,8 @@ except ImportError: from tensordict.utils import Buffer +_has_h5 = importlib.util.find_spec("h5py") is not None + # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key)` is a valid usage. @@ -1133,6 +1136,8 @@ def auto_device_(self) -> T: def from_dict( cls, input_dict, + *, + auto_batch_size: bool | None = None, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, @@ -1148,6 +1153,10 @@ def from_dict( Args: input_dict (dictionary, optional): a dictionary to use as a data source (nested keys compatible). + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions @@ -1213,6 +1222,7 @@ def _from_dict_validated(cls, *args, **kwargs): def from_dict_instance( self, input_dict, + *others, batch_size=None, device=None, batch_dims=None, @@ -9837,6 +9847,113 @@ def dict_to_namedtuple(dictionary): return dict_to_namedtuple(self.to_dict(retain_none=False)) + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): + """Converts any object to a TensorDict, recursively. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. + + Support includes: + + - dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not + tensorclasses). + - namedtuple through :meth:`~.from_namedtuple` + - dict through :meth:`~.from_dict` + - tuple through :meth:`~.from_tuple` + - numpy's structured arrays through :meth:`~.from_struct_array` + - h5 objects through :meth:`~.from_h5` + + """ + if isinstance(obj, dict): + return cls.from_dict(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): + return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) + from dataclasses import is_dataclass + + if is_dataclass(obj): + return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, tuple): + return cls.from_tuple(obj, auto_batch_size=auto_batch_size) + if isinstance(obj, list): + return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if _has_h5: + import h5py + + if isinstance(obj, h5py.File): + from tensordict.persistent import PersistentTensorDict + + obj = PersistentTensorDict(group=obj) + if auto_batch_size: + obj.auto_batch_size_() + return obj + return obj + + @classmethod + def from_tuple(cls, obj, *, auto_batch_size: bool = False): + from tensordict import TensorDict + + result = TensorDict({str(i): cls.from_any(item) for i, item in enumerate(obj)}) + if auto_batch_size: + result.auto_batch_size_() + return result + + @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): + """Converts a dataclass into a TensorDict instance. + + Args: + dataclass: The dataclass instance to be converted. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the + resulting TensorDict. Defaults to ``False``. + as_tensorclass (bool, optional): If ``True``, delegates the conversion to the free function + :func:`~tensordict.from_dataclass` and returns a tensor-compatible class + (:func:`~tensordict.tensorclass`) or instance instead of a ``TensorDict``. Defaults to ``False``. + + Returns: + A TensorDict instance derived from the provided dataclass, unless `as_tensorclass` is True, in which case a tensor-compatible class or instance is returned. + + Raises: + TypeError: If the provided input is not a dataclass instance. + + .. warning:: This method is distinct from the free function `from_dataclass` and serves a different purpose. + While the free function returns a tensor-compatible class or instance, this method returns a TensorDict instance. + + .. notes:: + + - This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass. + - Each key in the resulting TensorDict is initialized using the `cls.from_any` method. + - The `auto_batch_size` option allows for automatic batch size determination and application to the + resulting TensorDict. + + """ + if as_tensorclass: + from tensordict.tensorclass import from_dataclass + + return from_dataclass(dataclass, auto_batch_size=auto_batch_size) + from dataclasses import fields, is_dataclass + + from tensordict import TensorDict + + if not is_dataclass(dataclass): + raise TypeError( + f"Expected a dataclass input, got a {type(dataclass)} input instead." + ) + source = {} + for field in fields(dataclass): + source[field.name] = cls.from_any(getattr(dataclass, field.name)) + result = TensorDict(source) + if auto_batch_size: + result.auto_batch_size_() + return result + @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): """Converts a namedtuple to a TensorDict recursively. @@ -9885,8 +10002,7 @@ def namedtuple_to_dict(namedtuple_obj): "indices": namedtuple_obj.indices, } for key, value in namedtuple_obj.items(): - if is_namedtuple(value): - namedtuple_obj[key] = namedtuple_to_dict(value) + namedtuple_obj[key] = cls.from_any(value) return dict(namedtuple_obj) result = TensorDict(namedtuple_to_dict(named_tuple)) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index df7bad0e6..0b6dca196 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -921,7 +921,13 @@ def _exclude( @_carry_over def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, ): ... @_carry_over diff --git a/tensordict/persistent.py b/tensordict/persistent.py index d5f59110a..332023587 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -207,12 +207,25 @@ def from_h5(cls, filename, mode="r"): return out @classmethod - def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs): + def from_dict( + cls, + input_dict, + filename, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + **kwargs, + ): """Converts a dictionary or a TensorDict to a h5 file. Args: input_dict (dict, TensorDict or compatible): data to be stored as h5. filename (str or path): path to the h5 file. + + Keyword Args: + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. batch_size (tensordict batch-size, optional): if provided, batch size of the tensordict. If not, the batch size will be gathered from the input structure (if present) or determined automatically. @@ -225,6 +238,19 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs) A :class:`PersitentTensorDict` instance linked to the newly created file. """ + if others: + if batch_size is not None: + raise TypeError( + "conflicting batch size values. Please use the keyword argument only." + ) + warnings.warn( + "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead." + ) + if len(others) == 2: + batch_size, device = others + else: + batch_size = others[0] + import h5py file = h5py.File(filename, "w", locking=cls.LOCKING) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2556729e5..fc188e7cf 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -237,6 +237,12 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", + "from_any", + "from_dataclass", + "to_namedtuple", + "from_namedtuple", + "from_pytree", + "to_pytree", "gather", "isfinite", "isnan", @@ -379,6 +385,100 @@ def __call__(self, cls: T) -> T: return clz +def from_dataclass( + obj: Any, + *, + auto_batch_size: bool = False, + frozen: bool = False, + autocast: bool = False, + nocast: bool = False, +) -> Any: + """Converts a dataclass instance or a type into a tensorclass instance or type, respectively. + + This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class, + optionally applying various configurations such as auto-batching, immutability, and type casting. + + Args: + obj (Any): The dataclass instance or type to be converted. If a type is provided, a new class is returned. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting object. Defaults to ``False``. + frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``. + autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``. + nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``. + + Returns: + A tensor-compatible class or instance derived from the provided dataclass. + + Raises: + TypeError: If the provided input is not a dataclass instance or type. + + Examples: + >>> from dataclasses import dataclass + >>> import torch + >>> from tensordict.tensorclass import from_dataclass + >>> + >>> @dataclass + >>> class X: + ... a: int + ... b: torch.Tensor + ... + >>> x = X(0, 0) + >>> x2 = from_dataclass(x) + >>> print(x2) + X( + a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> X2 = from_dataclass(X, autocast=True) + >>> print(X2(a=0, b=0)) + X( + a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), + b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + batch_size=torch.Size([]), + device=None, + is_shared=False) + + .. notes:: If a dataclass type is provided, a new class is returned with the specified configurations. + If a dataclass instance is provided, a new instance of the tensor-compatible class is returned. + The `auto_batch_size`, `frozen`, `autocast`, and `nocast` options allow for flexible configuration of the resulting class or instance. + + .. warning:: Whereas :meth:`~tensordict.TensorDict.from_dataclass` will return a :class:`~tensordict.TensorDict` instance + by default, this method will return a tensorclass instance or type. + + """ + from dataclasses import asdict, is_dataclass, make_dataclass + + if isinstance(obj, type): + if is_tensorclass(obj): + return obj + cls = make_dataclass( + obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__ + ) + clz = _tensorclass(cls, frozen=frozen) + clz._type_hints = get_type_hints(obj) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + return clz + + if not is_dataclass(obj): + raise TypeError(f"Expected a obj input, got a {type(obj)} input instead.") + name = obj.__class__.__name__ + "_tc" + clz = _tensorclass( + make_dataclass(name, fields=obj.__dataclass_fields__), frozen=frozen + ) + clz._autocast = autocast + clz._nocast = nocast + clz._frozen = frozen + result = clz(**asdict(obj)) + if auto_batch_size: + result = result.auto_batch_size_() + return result + + @dataclass_transform() def tensorclass( cls: T = None, @@ -532,7 +632,8 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) - cls = dataclass(cls, frozen=frozen) + if not dataclasses.is_dataclass(cls): + cls = dataclass(cls, frozen=frozen) _TENSORCLASS_MEMO[cls] = True expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) @@ -2483,7 +2584,6 @@ def __post_init__(self): data_inner = data.tolist() del _tensordict["data"] _non_tensordict["data"] = data_inner - # assert _tensordict.is_empty(), self._tensordict # TODO: this will probably fail with dynamo at some point, + it's terrible. # Make sure it's patched properly at init time diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 75678b4b6..a77ef185a 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -209,9 +209,16 @@ class TensorClass: def auto_batch_size_(self, batch_dims: int | None = None) -> T: ... def auto_device_(self) -> T: ... @classmethod + def from_dataclass( + cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + ): ... + @classmethod + def from_any(cls, obj, *, auto_batch_size: bool = False): ... + @classmethod def from_dict( cls, input_dict, + *, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, diff --git a/tensordict/utils.py b/tensordict/utils.py index cdc0756f8..54e00a174 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -858,7 +858,7 @@ def is_tensorclass(obj: type | Any) -> bool: def _is_tensorclass(cls: type) -> bool: - out = _TENSORCLASS_MEMO.get(cls, None) + out = _TENSORCLASS_MEMO.get(cls) if out is None: out = getattr(cls, "_is_tensorclass", False) if not is_dynamo_compiling(): diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8879f0e68..ad1a194cd 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -53,7 +53,8 @@ class TestTensorDictsBase: TYPES_DEVICES = [] TYPES_DEVICES_NOLAZY = [] - def td(self, device): + @classmethod + def td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -68,7 +69,8 @@ def td(self, device): TYPES_DEVICES += [["td", device]] TYPES_DEVICES_NOLAZY += [["td", device]] - def nested_td(self, device): + @classmethod + def nested_td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -86,7 +88,8 @@ def nested_td(self, device): TYPES_DEVICES += [["nested_td", device]] TYPES_DEVICES_NOLAZY += [["nested_td", device]] - def nested_tensorclass(self, device): + @classmethod + def nested_tensorclass(cls, device): nested_class = MyClass( X=torch.randn(4, 3, 2, 1), @@ -119,8 +122,9 @@ def nested_tensorclass(self, device): TYPES_DEVICES += [["nested_tensorclass", device]] TYPES_DEVICES_NOLAZY += [["nested_tensorclass", device]] + @classmethod @set_lazy_legacy(True) - def nested_stacked_td(self, device): + def nested_stacked_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -140,8 +144,9 @@ def nested_stacked_td(self, device): TYPES_DEVICES += [["nested_stacked_td", device]] TYPES_DEVICES_NOLAZY += [["nested_stacked_td", device]] + @classmethod @set_lazy_legacy(True) - def stacked_td(self, device): + def stacked_td(cls, device): td1 = TensorDict( source={ "a": torch.randn(4, 3, 1, 5), @@ -165,7 +170,8 @@ def stacked_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["stacked_td", device]] - def idx_td(self, device): + @classmethod + def idx_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -180,7 +186,8 @@ def idx_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["idx_td", device]] - def sub_td(self, device): + @classmethod + def sub_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -195,7 +202,8 @@ def sub_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["sub_td", device]] - def sub_td2(self, device): + @classmethod + def sub_td2(cls, device): td = TensorDict( source={ "a": torch.randn(4, 2, 3, 2, 1, 5), @@ -212,17 +220,19 @@ def sub_td2(self, device): temp_path_memmap = tempfile.TemporaryDirectory() - def memmap_td(self, device): - path = pathlib.Path(self.temp_path_memmap.name) + @classmethod + def memmap_td(cls, device): + path = pathlib.Path(cls.temp_path_memmap.name) shutil.rmtree(path) path.mkdir() - return self.td(device).memmap_(path) + return cls.td(device).memmap_(path) TYPES_DEVICES += [["memmap_td", torch.device("cpu")]] TYPES_DEVICES_NOLAZY += [["memmap_td", torch.device("cpu")]] + @classmethod @set_lazy_legacy(True) - def permute_td(self, device): + def permute_td(cls, device): return TensorDict( source={ "a": torch.randn(3, 1, 4, 2, 5), @@ -236,8 +246,9 @@ def permute_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["permute_td", device]] + @classmethod @set_lazy_legacy(True) - def unsqueezed_td(self, device): + def unsqueezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 5), @@ -252,8 +263,9 @@ def unsqueezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["unsqueezed_td", device]] + @classmethod @set_lazy_legacy(True) - def squeezed_td(self, device): + def squeezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 1, 2, 1, 5), @@ -268,7 +280,8 @@ def squeezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["squeezed_td", device]] - def td_reset_bs(self, device): + @classmethod + def td_reset_bs(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -285,13 +298,14 @@ def td_reset_bs(self, device): TYPES_DEVICES += [["td_reset_bs", device]] TYPES_DEVICES_NOLAZY += [["td_reset_bs", device]] + @classmethod def td_h5( - self, + cls, device, ): file = tempfile.NamedTemporaryFile() filename = file.name - nested_td = self.nested_td(device) + nested_td = cls.nested_td(device) td_h5 = PersistentTensorDict.from_dict( nested_td, filename=filename, device=device ) @@ -303,15 +317,17 @@ def td_h5( TYPES_DEVICES += [["td_h5", device]] TYPES_DEVICES_NOLAZY += [["td_h5", device]] - def td_params(self, device): - return TensorDictParams(self.td(device)) + @classmethod + def td_params(cls, device): + return TensorDictParams(cls.td(device)) for device in get_available_devices(): TYPES_DEVICES += [["td_params", device]] TYPES_DEVICES_NOLAZY += [["td_params", device]] - def td_with_non_tensor(self, device): - td = self.td(device) + @classmethod + def td_with_non_tensor(cls, device): + td = cls.td(device) return td.set_non_tensor( ("data", "non_tensor"), # this is allowed since nested NonTensorData are automatically unwrapped diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 4753c3704..127d4b77a 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -23,6 +23,7 @@ import tensordict.utils import torch from tensordict import TensorClass +from tensordict.tensorclass import from_dataclass try: import torchsnapshot @@ -94,6 +95,21 @@ class MyData2: z: list +@dataclasses.dataclass +class MyDataClass: + a: int + b: torch.Tensor + c: str + + +try: + MyTensorClass_autocast = from_dataclass(MyDataClass, autocast=True) + MyTensorClass_nocast = from_dataclass(MyDataClass, nocast=True) + MyTensorClass = from_dataclass(MyDataClass) +except Exception: + MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None + + class TestTensorClass: def test_all_any(self): @tensorclass @@ -517,6 +533,43 @@ class MyClass2: assert (a != c.clone().zero_()).any() assert (c != a.clone().zero_()).any() + def test_from_dataclass(self): + assert is_tensorclass(MyTensorClass_autocast) + assert MyTensorClass_nocast is not MyDataClass + assert MyTensorClass_autocast._autocast + x = MyTensorClass_autocast(a=0, b=0, c=0) + assert isinstance(x.a, int) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, str) + + assert is_tensorclass(MyTensorClass_nocast) + assert MyTensorClass_nocast is not MyTensorClass_autocast + assert MyTensorClass_nocast._nocast + + x = MyTensorClass_nocast(a=0, b=0, c=0) + assert is_tensorclass(MyTensorClass) + assert not MyTensorClass._autocast + assert not MyTensorClass._nocast + assert isinstance(x.a, int) + assert isinstance(x.b, int) + assert isinstance(x.c, int) + + x = MyTensorClass(a=0, b=0, c=0) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + + x = TensorDict.from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert isinstance(x, TensorDict) + assert isinstance(x["a"], torch.Tensor) + assert isinstance(x["b"], torch.Tensor) + assert isinstance(x["c"], torch.Tensor) + x = from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert is_tensorclass(x) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + def test_from_dict(self): td = TensorDict( { diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 257c2b712..7d4148cd9 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -20,6 +20,7 @@ import warnings from dataclasses import dataclass from pathlib import Path +from typing import Any import numpy as np import pytest @@ -949,6 +950,55 @@ def test_fromkeys(self): td = TensorDict.fromkeys({"a", "b", "c"}, 1) assert td["a"] == 1 + def test_from_any(self): + from dataclasses import dataclass + + @dataclass + class MyClass: + a: int + + pytree = ( + [torch.randint(10, (3,)), torch.zeros(2)], + { + "tensor": torch.randn( + 2, + ), + "td": TensorDict({"one": 1}), + "tuple": (1, 2, 3), + }, + {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, + {"dataclass": MyClass(a=0)}, + ) + if _has_h5py: + pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) + td = TensorDict.from_any(pytree) + assert set(td.keys(True, True)) == { + ("0", "0"), + ("0", "1"), + ("1", "td", "one"), + ("1", "tensor"), + ("1", "tuple", "0"), + ("1", "tuple", "1"), + ("1", "tuple", "2"), + ("2", "named_tuple", "two"), + ("4", "h5py", "a"), + ("4", "h5py", "b"), + ("4", "h5py", "c"), + ("4", "h5py", "my_nested_td", "inner"), + } + + def test_from_dataclass(self): + @dataclass + class MyClass: + a: int + b: Any + + obj = MyClass(a=0, b=1) + obj_td = TensorDict.from_dataclass(obj) + obj_tc = TensorDict.from_dataclass(obj, as_tensorclass=True) + assert is_tensorclass(obj_tc) + assert not is_tensorclass(obj_td) + @pytest.mark.parametrize("batch_size", [None, [3, 4]]) @pytest.mark.parametrize("batch_dims", [None, 1, 2]) @pytest.mark.parametrize("device", get_available_devices()) From 6579f5eefa53e717d46119006d1fdaeab12134a4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Nov 2024 20:28:28 +0100 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- tensordict/_lazy.py | 23 ++++++++++--------- tensordict/_reductions.py | 1 + tensordict/_td.py | 48 +++++++++++++++++++++++---------------- tensordict/base.py | 5 ++++ tensordict/functional.py | 5 +++- tensordict/tensorclass.py | 30 +++++++++++++++++------- test/test_tensordict.py | 48 +++++++++++++++++++++++---------------- 7 files changed, 101 insertions(+), 59 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index fc87258bf..73c316981 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -338,17 +338,17 @@ def from_dict( stack_dim_name=None, stack_dim=0, ): - if batch_size is not None: - batch_size = list(batch_size) - if stack_dim is None: - stack_dim = 0 - n = batch_size.pop(stack_dim) - if n != len(input_dict): - raise ValueError( - "The number of dicts and the corresponding batch-size must match, " - f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." - ) - batch_size = torch.Size(batch_size) + # if batch_size is not None: + # batch_size = list(batch_size) + # if stack_dim is None: + # stack_dim = 0 + # n = batch_size.pop(stack_dim) + # if n != len(input_dict): + # raise ValueError( + # "The number of dicts and the corresponding batch-size must match, " + # f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." + # ) + # batch_size = torch.Size(batch_size) return LazyStackedTensorDict( *( TensorDict.from_dict( @@ -357,6 +357,7 @@ def from_dict( auto_batch_size=auto_batch_size, device=device, batch_dims=batch_dims, + batch_size=batch_size, ) for i in range(len(input_dict)) ), diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index be8aa42f1..0143ec856 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -121,6 +121,7 @@ def from_metadata(metadata=metadata, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) + print('cls_metadata', cls_metadata) result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() diff --git a/tensordict/_td.py b/tensordict/_td.py index 3029aeab0..66f4dc86b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -615,7 +615,7 @@ def __ne__(self, other: object) -> T | bool: if is_tensorclass(other): return other != self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -639,7 +639,7 @@ def __xor__(self, other: object) -> T | bool: if is_tensorclass(other): return other ^ self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -663,7 +663,7 @@ def __or__(self, other: object) -> T | bool: if is_tensorclass(other): return other | self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -687,7 +687,7 @@ def __eq__(self, other: object) -> T | bool: if is_tensorclass(other): return other == self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -709,7 +709,7 @@ def __ge__(self, other: object) -> T | bool: if is_tensorclass(other): return other <= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -731,7 +731,7 @@ def __gt__(self, other: object) -> T | bool: if is_tensorclass(other): return other < self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -753,7 +753,7 @@ def __le__(self, other: object) -> T | bool: if is_tensorclass(other): return other >= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -775,7 +775,7 @@ def __lt__(self, other: object) -> T | bool: if is_tensorclass(other): return other > self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -2019,7 +2019,7 @@ def from_dict( names=names, ) if batch_size is None: - if auto_batch_size is None: + if auto_batch_size is None and batch_dims is None: warn( "The batch-size was not provided and auto_batch_size isn't set either. " "Currently, from_dict will call set auto_batch_size=True but this behaviour " @@ -2028,6 +2028,8 @@ def from_dict( category=DeprecationWarning, ) auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True if auto_batch_size: _set_max_batch_size(out, batch_dims) else: @@ -2099,23 +2101,26 @@ def from_dict_instance( # TODO: v0.7: remove the None cur_value = self.get(key, None) if cur_value is not None: + print(type(cur_value)) input_dict[key] = cur_value.from_dict_instance( value, device=device, - auto_batch_size=auto_batch_size, + auto_batch_size=False, ) + print(type(cur_value), type(input_dict[key])) continue - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, - device=device, - auto_batch_size=auto_batch_size, - ) + else: + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_dict( + value, + device=device, + auto_batch_size=False, + ) else: input_dict[key] = TensorDict.from_any( value, - auto_batch_size=auto_batch_size, + auto_batch_size=False, ) out = TensorDict.from_dict( @@ -2125,7 +2130,7 @@ def from_dict_instance( names=names, ) if batch_size is None: - if auto_batch_size is None: + if auto_batch_size is None and batch_dims is None: warn( "The batch-size was not provided and auto_batch_size isn't set either. " "Currently, from_dict will call set auto_batch_size=True but this behaviour " @@ -2134,8 +2139,13 @@ def from_dict_instance( category=DeprecationWarning, ) auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True if auto_batch_size: + print('self', self) + print('out', out) _set_max_batch_size(out, batch_dims) + print('out', out) else: out.batch_size = batch_size return out diff --git a/tensordict/base.py b/tensordict/base.py index bea8af10c..b9a4077d5 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1216,6 +1216,8 @@ def _from_dict_validated(cls, *args, **kwargs): By default, falls back on :meth:`~.from_dict`. """ + kwargs.setdefault("auto_batch_size", True) + print('kwargs', kwargs) return cls.from_dict(*args, **kwargs) @abc.abstractmethod @@ -1223,6 +1225,7 @@ def from_dict_instance( self, input_dict, *others, + auto_batch_size: bool | None=None, batch_size=None, device=None, batch_dims=None, @@ -9866,6 +9869,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): - h5 objects through :meth:`~.from_h5` """ + if is_tensor_collection(obj): + return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): diff --git a/tensordict/functional.py b/tensordict/functional.py index a40095141..edd93a36f 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -437,6 +437,7 @@ def make_tensordict( input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, + auto_batch_size:bool|None=None, **kwargs: CompatibleType, # source ) -> TensorDict: """Returns a TensorDict created from the keyword arguments or an input dictionary. @@ -453,6 +454,8 @@ def make_tensordict( (incompatible with nested keys). batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. Examples: >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} @@ -500,4 +503,4 @@ def make_tensordict( """ if input_dict is not None: kwargs.update(input_dict) - return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device) + return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index fc188e7cf..07b6d5faa 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -27,6 +27,7 @@ from textwrap import indent from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar +from warnings import warn import numpy as np import orjson as json @@ -632,8 +633,9 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) - if not dataclasses.is_dataclass(cls): - cls = dataclass(cls, frozen=frozen) + # Breaks some tests, don't do that: + # if not dataclasses.is_dataclass(cls): + cls = dataclass(cls, frozen=frozen) _TENSORCLASS_MEMO[cls] = True expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) @@ -1367,7 +1369,7 @@ def _update( non_blocking: bool = False, ): if isinstance(input_dict_or_td, dict): - input_dict_or_td = self.from_dict(input_dict_or_td) + input_dict_or_td = self.from_dict(input_dict_or_td, auto_batch_size=False) if is_tensorclass(input_dict_or_td): non_tensordict = { @@ -1579,7 +1581,7 @@ def _to_dict(self, *, retain_none: bool = True) -> dict: return td_dict -def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): +def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None): # we pass through a tensordict because keys could be passed as NestedKeys # We can't assume all keys are strings, otherwise calling cls(**kwargs) # would work ok @@ -1593,7 +1595,7 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): non_tensordict=input_dict, ) td = TensorDict.from_dict( - input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims + input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims, auto_batch_size=auto_batch_size ) non_tensordict = {} @@ -1601,7 +1603,7 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None): def _from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None ): if batch_dims is not None and batch_size is not None: raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.") @@ -1611,7 +1613,7 @@ def _from_dict_instance( # TODO: this is a bit slow and will be a bottleneck every time td[idx] = dict(subtd) # is called when there are non tensor data in it if not _is_tensor_collection(type(input_dict)): - input_tdict = TensorDict.from_dict(input_dict) + input_tdict = TensorDict.from_dict(input_dict, auto_batch_size=auto_batch_size) else: input_tdict = input_dict trsf_dict = {} @@ -1639,7 +1641,19 @@ def _from_dict_instance( ) # check that if batch_size is None: - out._tensordict.auto_batch_size_() + if auto_batch_size is None and batch_dims is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True + if auto_batch_size: + out.auto_batch_size_() return out diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7d4148cd9..efe47b480 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -972,7 +972,7 @@ class MyClass: if _has_h5py: pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) td = TensorDict.from_any(pytree) - assert set(td.keys(True, True)) == { + expected = { ("0", "0"), ("0", "1"), ("1", "td", "one"), @@ -981,11 +981,18 @@ class MyClass: ("1", "tuple", "1"), ("1", "tuple", "2"), ("2", "named_tuple", "two"), - ("4", "h5py", "a"), - ("4", "h5py", "b"), - ("4", "h5py", "c"), - ("4", "h5py", "my_nested_td", "inner"), + ("3", "dataclass", "a"), } + if _has_h5py: + expected = expected.union( + { + ("4", "h5py", "a"), + ("4", "h5py", "b"), + ("4", "h5py", "c"), + ("4", "h5py", "my_nested_td", "inner"), + } + ) + assert set(td.keys(True, True)) == expected, set(td.keys(True, True)).symmetric_difference(expected) def test_from_dataclass(self): @dataclass @@ -1017,7 +1024,7 @@ def test_from_dict(self, batch_size, batch_dims, device): ) return data = TensorDict.from_dict( - data, batch_size=batch_size, batch_dims=batch_dims, device=device + data, batch_size=batch_size, batch_dims=batch_dims, device=device, auto_batch_size=True ) assert data.device == device assert "a" in data.keys() @@ -1051,7 +1058,7 @@ class MyClass: assert isinstance(td_dict["b"]["y"], int) assert isinstance(td_dict["b"]["z"], dict) assert isinstance(td_dict["b"]["z"]["y"], int) - td_recon = td.from_dict_instance(td_dict) + td_recon = td.from_dict_instance(td_dict, auto_batch_size=True) assert isinstance(td_recon["a"], torch.Tensor) assert isinstance(td_recon["b"], MyClass) assert isinstance(td_recon["b"].x, torch.Tensor) @@ -6493,7 +6500,7 @@ def recursive_checker(cur_dict): assert recursive_checker(td_dict) if td_name == "td_with_non_tensor": assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict) == td).all() + assert (TensorDict.from_dict(td_dict,auto_batch_size=False) == td).all() def test_to_namedtuple(self, td_name, device): def is_namedtuple(obj): @@ -7821,7 +7828,7 @@ def test_mp(self, td_type, unbind_as): class TestMakeTensorDict: def test_create_tensordict(self): - tensordict = make_tensordict(a=torch.zeros(3, 4)) + tensordict = make_tensordict(a=torch.zeros(3, 4), auto_batch_size=True) assert (tensordict["a"] == torch.zeros(3, 4)).all() def test_nested(self): @@ -7829,7 +7836,7 @@ def test_nested(self): "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)}, "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_tensordict = TensorDict( @@ -7839,7 +7846,7 @@ def test_nested(self): }, [], ) - tensordict = make_tensordict(input_tensordict) + tensordict = make_tensordict(input_tensordict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_dict = { @@ -7847,30 +7854,30 @@ def test_nested(self): ("a", "c"): torch.randn(3, 4, 5), "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) def test_tensordict_batch_size(self): - tensordict = make_tensordict() + tensordict = make_tensordict(auto_batch_size=True) assert tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4)) + tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5)) # nested + nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5), auto_batch_size=True) # nested assert nested_tensordict.batch_size == torch.Size([3]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5)) # nested + nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5), auto_batch_size=True) # nested assert nested_tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1)) + tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True) assert tensordict.batch_size == torch.Size([]) tensordict = make_tensordict( @@ -7886,7 +7893,7 @@ def test_tensordict_batch_size(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_device(self, device): tensordict = make_tensordict( - a=torch.randn(3, 4), b=torch.randn(3, 4), device=device + a=torch.randn(3, 4), b=torch.randn(3, 4), device=device, auto_batch_size=True ) assert tensordict.device == device assert tensordict["a"].device == device @@ -7897,6 +7904,7 @@ def test_tensordict_device(self, device): b=torch.randn(3, 4), c=torch.randn(3, 4, device="cpu"), device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device From c01404cf07c2cf252287090023b883c458109bb4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 22 Nov 2024 20:44:33 +0100 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- tensordict/_reductions.py | 1 - tensordict/_td.py | 5 ----- tensordict/base.py | 3 +-- tensordict/functional.py | 6 ++++-- tensordict/nn/common.py | 4 +++- tensordict/tensorclass.py | 26 ++++++++++++++++++++++---- test/test_tensorclass.py | 10 +++++++--- test/test_tensordict.py | 37 ++++++++++++++++++++++++++++--------- 8 files changed, 65 insertions(+), 27 deletions(-) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 0143ec856..be8aa42f1 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -121,7 +121,6 @@ def from_metadata(metadata=metadata, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - print('cls_metadata', cls_metadata) result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() diff --git a/tensordict/_td.py b/tensordict/_td.py index 66f4dc86b..07a98cdfb 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2101,13 +2101,11 @@ def from_dict_instance( # TODO: v0.7: remove the None cur_value = self.get(key, None) if cur_value is not None: - print(type(cur_value)) input_dict[key] = cur_value.from_dict_instance( value, device=device, auto_batch_size=False, ) - print(type(cur_value), type(input_dict[key])) continue else: # we don't know if another tensor of smaller size is coming @@ -2142,10 +2140,7 @@ def from_dict_instance( elif auto_batch_size is None: auto_batch_size = True if auto_batch_size: - print('self', self) - print('out', out) _set_max_batch_size(out, batch_dims) - print('out', out) else: out.batch_size = batch_size return out diff --git a/tensordict/base.py b/tensordict/base.py index b9a4077d5..3666b6772 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1217,7 +1217,6 @@ def _from_dict_validated(cls, *args, **kwargs): By default, falls back on :meth:`~.from_dict`. """ kwargs.setdefault("auto_batch_size", True) - print('kwargs', kwargs) return cls.from_dict(*args, **kwargs) @abc.abstractmethod @@ -1225,7 +1224,7 @@ def from_dict_instance( self, input_dict, *others, - auto_batch_size: bool | None=None, + auto_batch_size: bool | None = None, batch_size=None, device=None, batch_dims=None, diff --git a/tensordict/functional.py b/tensordict/functional.py index edd93a36f..2699f36bb 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -437,7 +437,7 @@ def make_tensordict( input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, - auto_batch_size:bool|None=None, + auto_batch_size: bool | None = None, **kwargs: CompatibleType, # source ) -> TensorDict: """Returns a TensorDict created from the keyword arguments or an input dictionary. @@ -503,4 +503,6 @@ def make_tensordict( """ if input_dict is not None: kwargs.update(input_dict) - return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size) + return TensorDict.from_dict( + kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size + ) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0b55d1cef..ffedba9ad 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -297,9 +297,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: f"The key {expected_key} wasn't found in the keyword arguments " f"but is expected to execute that function." ) + batch_size = torch.Size([]) if not self.auto_batch_size else None tensordict = make_tensordict( tensordict_values, - batch_size=torch.Size([]) if not self.auto_batch_size else None, + batch_size=batch_size, + auto_batch_size=False, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 07b6d5faa..1fb7d3049 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1581,7 +1581,15 @@ def _to_dict(self, *, retain_none: bool = True) -> dict: return td_dict -def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None): +def _from_dict( + cls, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, +): # we pass through a tensordict because keys could be passed as NestedKeys # We can't assume all keys are strings, otherwise calling cls(**kwargs) # would work ok @@ -1595,7 +1603,11 @@ def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=No non_tensordict=input_dict, ) td = TensorDict.from_dict( - input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims, auto_batch_size=auto_batch_size + input_dict, + batch_size=batch_size, + device=device, + batch_dims=batch_dims, + auto_batch_size=auto_batch_size, ) non_tensordict = {} @@ -1603,7 +1615,13 @@ def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=No def _from_dict_instance( - self, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, ): if batch_dims is not None and batch_size is not None: raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.") @@ -1773,7 +1791,7 @@ def _is_castable(datatype): if isinstance(value, dict): if _is_tensor_collection(target_cls): - cast_val = target_cls.from_dict(value) + cast_val = target_cls.from_dict(value, auto_batch_size=False) self._tensordict.set( key, cast_val, inplace=inplace, non_blocking=non_blocking ) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 127d4b77a..0f71bd743 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -584,7 +584,7 @@ def test_from_dict(self): class MyClass: a: TensorDictBase - tc = MyClass.from_dict(d) + tc = MyClass.from_dict(d, auto_batch_size=True) assert isinstance(tc, MyClass) assert isinstance(tc.a, TensorDict) assert tc.batch_size == torch.Size([10]) @@ -2148,7 +2148,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3] ) - assert (test_class == TestClass.from_dict(test_class.to_dict())).all() + assert ( + test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True) + ).all() # Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such # test_class2 = TestClass( @@ -2161,7 +2163,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3] ) - assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all() + assert not ( + test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True) + ).all() @tensorclass(autocast=True) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index efe47b480..372e1af6a 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -992,7 +992,9 @@ class MyClass: ("4", "h5py", "my_nested_td", "inner"), } ) - assert set(td.keys(True, True)) == expected, set(td.keys(True, True)).symmetric_difference(expected) + assert set(td.keys(True, True)) == expected, set( + td.keys(True, True) + ).symmetric_difference(expected) def test_from_dataclass(self): @dataclass @@ -1024,7 +1026,11 @@ def test_from_dict(self, batch_size, batch_dims, device): ) return data = TensorDict.from_dict( - data, batch_size=batch_size, batch_dims=batch_dims, device=device, auto_batch_size=True + data, + batch_size=batch_size, + batch_dims=batch_dims, + device=device, + auto_batch_size=True, ) assert data.device == device assert "a" in data.keys() @@ -6500,7 +6506,7 @@ def recursive_checker(cur_dict): assert recursive_checker(td_dict) if td_name == "td_with_non_tensor": assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict,auto_batch_size=False) == td).all() + assert (TensorDict.from_dict(td_dict, auto_batch_size=False) == td).all() def test_to_namedtuple(self, td_name, device): def is_namedtuple(obj): @@ -7865,19 +7871,29 @@ def test_tensordict_batch_size(self): tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5), auto_batch_size=True) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(3, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([3]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5), auto_batch_size=True) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(4, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([]) tensordict = make_tensordict( @@ -7893,7 +7909,10 @@ def test_tensordict_batch_size(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_device(self, device): tensordict = make_tensordict( - a=torch.randn(3, 4), b=torch.randn(3, 4), device=device, auto_batch_size=True + a=torch.randn(3, 4), + b=torch.randn(3, 4), + device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device From 6818c8c39cbd94003daba0a9a49e6a3f209613b6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 14:44:57 +0100 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- tensordict/base.py | 13 +++++++------ tensordict/tensorclass.py | 3 ++- tensordict/utils.py | 11 +++++++++++ test/test_tensordict.py | 5 +++-- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 3666b6772..3dea5a7d0 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -55,6 +55,7 @@ _CloudpickleWrapper, _DTYPE2STRDTYPE, _GENERIC_NESTED_ERR, + _is_dataclass as is_dataclass, _is_non_tensor, _is_number, _is_tensorclass, @@ -9452,6 +9453,7 @@ def _validate_value( if device is not None and value.device != device: if _device_recorder.marked and device.type != "cuda": _device_recorder.record_transfer(device) + assert not non_blocking value = value.to(device, non_blocking=non_blocking) if check_shape: if is_tc is None: @@ -9874,16 +9876,15 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_dict(obj, auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) - from dataclasses import is_dataclass - if is_dataclass(obj): - return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) - if is_namedtuple(obj): - return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, tuple): return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if is_dataclass(obj): + return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) if _has_h5: import h5py @@ -9942,7 +9943,7 @@ def from_dataclass( from tensordict.tensorclass import from_dataclass return from_dataclass(dataclass, auto_batch_size=auto_batch_size) - from dataclasses import fields, is_dataclass + from dataclasses import fields from tensordict import TensorDict diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 1fb7d3049..e1c8e77b4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -46,6 +46,7 @@ CompatibleType, ) from tensordict.utils import ( # @manual=//pytorch/tensordict:_C + _is_dataclass as is_dataclass, _is_json_serializable, _is_tensorclass, _LOCK_ERROR, @@ -450,7 +451,7 @@ def from_dataclass( by default, this method will return a tensorclass instance or type. """ - from dataclasses import asdict, is_dataclass, make_dataclass + from dataclasses import asdict, make_dataclass if isinstance(obj, type): if is_tensorclass(obj): diff --git a/tensordict/utils.py b/tensordict/utils.py index 54e00a174..81ab2fa0c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -20,6 +20,7 @@ from collections import defaultdict from collections.abc import KeysView from copy import copy +from dataclasses import _FIELDS, GenericAlias from functools import wraps from importlib import import_module from numbers import Number @@ -2813,3 +2814,13 @@ def _mismatch_keys(keys1, keys2): if sub2 is not None: main.append(sub2) raise KeyError(r" ".join(main)) + + +def _is_dataclass(obj): + """Like dataclasses.is_dataclass but compatible with compile.""" + cls = ( + obj + if isinstance(obj, type) and not isinstance(obj, GenericAlias) + else type(obj) + ) + return hasattr(cls, _FIELDS) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 372e1af6a..73d401c03 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -10661,7 +10661,8 @@ def test_non_tensor_call(self): def test_nontensor_dict(self, non_tensor_data): assert ( - TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data + TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True) + == non_tensor_data ).all() def test_nontensor_tensor(self): @@ -11202,7 +11203,7 @@ def _to_float(td, td_name, tmpdir): td._source = td._source.float() elif td_name in ("td_h5",): td = PersistentTensorDict.from_dict( - td.float().to_dict(), filename=tmpdir + "/file.t" + td.float().to_dict(), filename=tmpdir + "/file.t", auto_batch_size=True ) elif td_name in ("td_params",): td = TensorDictParams(td.data.float()) From 1fd8edaabfa4182b99960f16874bc7f16e6fcc4d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 15:17:10 +0100 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- tensordict/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 3dea5a7d0..6b7afd051 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -124,7 +124,6 @@ class _NoDefault(enum.IntEnum): NO_DEFAULT = _NoDefault.ZERO -assert not NO_DEFAULT T = TypeVar("T", bound="TensorDictBase") @@ -4297,7 +4296,6 @@ def _view_and_pad(tensor): elif k[-1].startswith(""): # NJT/NT always comes before offsets/shapes nt = oldv - assert not v.numel() nt_lengths = None del flat_dict[k] elif k[-1].startswith(""): @@ -9453,7 +9451,6 @@ def _validate_value( if device is not None and value.device != device: if _device_recorder.marked and device.type != "cuda": _device_recorder.record_transfer(device) - assert not non_blocking value = value.to(device, non_blocking=non_blocking) if check_shape: if is_tc is None: From 299aa74f5370e9d04e0c7bac17f90a6dc7e174ae Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 15:33:28 +0100 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- tensordict/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 6b7afd051..6c600b11f 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9873,15 +9873,14 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_dict(obj, auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) - if isinstance(obj, tuple): + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) if is_dataclass(obj): return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) - if is_namedtuple(obj): - return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) if _has_h5: import h5py