From 270d7bab427874ab1398a8e1158eb058473061e9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 16:15:37 +0000 Subject: [PATCH] [Refactor] Refactor context managers ghstack-source-id: c16baa83f6e41c4afd6637f3b3739d4e5cf25f1e Pull Request resolved: https://github.com/pytorch/tensordict/pull/1098 --- tensordict/_td.py | 2 +- tensordict/base.py | 8 ++- tensordict/nn/params.py | 6 +-- tensordict/nn/probabilistic.py | 14 +++--- tensordict/nn/utils.py | 39 +++++++------- tensordict/utils.py | 50 +++++++++++------- test/test_nn.py | 60 ++++++++++------------ test/test_tensordict.py | 92 +++++++++++++++++----------------- 8 files changed, 139 insertions(+), 132 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 07a98cdfb..c6b03e1ae 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -609,7 +609,7 @@ def _quick_set(swap_dict, swap_td): _quick_set(_swap, swap_dest) return swap_dest else: - return TensorDict._new_unsafe(_swap, batch_size=[]) + return TensorDict._new_unsafe(_swap, batch_size=torch.Size(())) def __ne__(self, other: object) -> T | bool: if is_tensorclass(other): diff --git a/tensordict/base.py b/tensordict/base.py index 159777e28..91b94e714 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9870,8 +9870,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): """ if is_tensor_collection(obj): - if is_non_tensor(obj): - return cls.from_any(obj.data, auto_batch_size=auto_batch_size) + # Conversions from non-tensor data must be done manually + # if is_non_tensor(obj): + # from tensordict.tensorclass import LazyStackedTensorDict + # if isinstance(obj, LazyStackedTensorDict): + # return obj + # return cls.from_any(obj.data, auto_batch_size=auto_batch_size) return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index bc07b7689..07e355746 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -111,10 +111,8 @@ def _maybe_make_param(tensor): def _maybe_make_param_or_buffer(tensor): - if ( - isinstance(tensor, (Tensor, ftdim.Tensor)) - and not isinstance(tensor, (nn.Parameter, Buffer)) - and tensor.dtype in (torch.float, torch.double, torch.half) + if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance( + tensor, (nn.Parameter, Buffer) ): if not tensor.requires_grad and not is_batchedtensor(tensor): # convert all non-parameters to buffers diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index e13c43f6d..61df65d2b 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -22,7 +22,7 @@ from tensordict.nn.utils import _set_skip_existing_None from tensordict.tensorclass import is_non_tensor from tensordict.tensordict import TensorDictBase -from tensordict.utils import _zip_strict +from tensordict.utils import _ContextManager, _zip_strict from torch import distributions as D, Tensor from torch.utils._contextlib import _DecoratorContextManager @@ -66,12 +66,12 @@ def from_str(cls, type_str: str) -> InteractionType: return cls(type_str.lower()) -_INTERACTION_TYPE: InteractionType | None = None +_interaction_type = _ContextManager() def interaction_type() -> InteractionType | None: """Returns the current sampling type.""" - return _INTERACTION_TYPE + return _interaction_type.get_mode() class set_interaction_type(_DecoratorContextManager): @@ -98,13 +98,11 @@ def clone(self) -> set_interaction_type: return type(self)(self.type) def __enter__(self) -> None: - global 
_INTERACTION_TYPE - self.prev = _INTERACTION_TYPE - _INTERACTION_TYPE = self.type + self.prev = _interaction_type.get_mode() + _interaction_type.set_mode(self.type) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - global _INTERACTION_TYPE - _INTERACTION_TYPE = self.prev + _interaction_type.set_mode(self.prev) class ProbabilisticTensorDictModule(TensorDictModuleBase): diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 177217b98..f3bcd86c6 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -12,22 +12,24 @@ from typing import Any, Callable import torch -from tensordict.utils import strtobool +from tensordict.utils import _ContextManager, strtobool from torch import nn +from torch.utils._contextlib import _DecoratorContextManager + try: from torch.compiler import is_dynamo_compiling except ImportError: # torch 2.0 from torch._dynamo import is_compiling as is_dynamo_compiling -DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) +_dispatch_tdnn_modules = _ContextManager( + default=strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) +) __all__ = ["mappings", "inv_softplus", "biased_softplus"] -_SKIP_EXISTING = False - -from torch.utils._contextlib import _DecoratorContextManager +_skip_existing = _ContextManager(default=False) def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor: @@ -300,10 +302,9 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: def __enter__(self) -> None: if self.mode and is_dynamo_compiling(): raise RuntimeError("skip_existing is not compatible with TorchDynamo.") - global _SKIP_EXISTING - self.prev = _SKIP_EXISTING + self.prev = _skip_existing.get_mode() if self.mode is not None: - _SKIP_EXISTING = self.mode + _skip_existing.set_mode(self.mode) elif not self._called: raise RuntimeError( f"It seems you are using {type(self).__name__} as a context manager with ``None`` input. " @@ -311,8 +312,7 @@ def __enter__(self) -> None: ) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - global _SKIP_EXISTING - _SKIP_EXISTING = self.prev + _skip_existing.set_mode(self.prev) class _set_skip_existing_None(set_skip_existing): @@ -353,12 +353,11 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: return tensordict if is_dynamo_compiling(): return func(_self, tensordict, *args, **kwargs) - global _SKIP_EXISTING - self.prev = _SKIP_EXISTING + self.prev = _skip_existing.get_mode() try: result = func(_self, tensordict, *args, **kwargs) finally: - _SKIP_EXISTING = self.prev + _skip_existing.set_mode(self.prev) return result return wrapper @@ -375,7 +374,7 @@ def clone(self) -> _set_skip_existing_None: def skip_existing(): """Returns whether or not existing entries in a tensordict should be re-computed by a module.""" - return _SKIP_EXISTING + return _skip_existing.get_mode() def _rebuild_buffer(data, requires_grad, backward_hooks): @@ -397,7 +396,7 @@ def _rebuild_buffer(data, requires_grad, backward_hooks): def _dispatch_td_nn_modules(): """Returns ``True`` if @dispatch should be used. 
Not using dispatch is faster and also better compatible with torch.compile.""" - return DISPATCH_TDNN_MODULES + return _dispatch_tdnn_modules.get_mode() class _set_dispatch_td_nn_modules(_DecoratorContextManager): @@ -411,17 +410,15 @@ def clone(self): return type(self)(self.mode) def __enter__(self): - global DISPATCH_TDNN_MODULES # We want to avoid changing global variables because compile puts guards on them - if DISPATCH_TDNN_MODULES != self.mode: - self._saved_mode = DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self.mode + if _dispatch_tdnn_modules.get_mode() != self.mode: + self._saved_mode = _dispatch_tdnn_modules + _dispatch_tdnn_modules.set_mode(self.mode) def __exit__(self, exc_type, exc_val, exc_tb): if self._saved_mode is None: return - global DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self._saved_mode + _dispatch_tdnn_modules.set_mode(self._saved_mode) # Reproduce StrEnum for python<3.11 diff --git a/tensordict/utils.py b/tensordict/utils.py index 2344da517..0e370856f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -15,10 +15,12 @@ import re import sys +import threading import time import warnings from collections import defaultdict from collections.abc import KeysView +from contextlib import nullcontext from copy import copy from dataclasses import _FIELDS, GenericAlias from functools import wraps @@ -66,12 +68,9 @@ except ImportError: _has_funcdim = False try: - from torch.compiler import assume_constant_result, is_dynamo_compiling + from torch.compiler import assume_constant_result, is_compiling except ImportError: # torch 2.0 - from torch._dynamo import ( - assume_constant_result, - is_compiling as is_dynamo_compiling, - ) + from torch._dynamo import assume_constant_result, is_compiling if TYPE_CHECKING: from tensordict.tensordict import TensorDictBase @@ -862,7 +861,7 @@ def _is_tensorclass(cls: type) -> bool: out = _TENSORCLASS_MEMO.get(cls) if out is None: out = getattr(cls, "_is_tensorclass", False) - if not is_dynamo_compiling(): + if not is_compiling(): _TENSORCLASS_MEMO[cls] = out return out @@ -1118,7 +1117,7 @@ def cache(fun): @wraps(fun) def newfun(_self: "TensorDictBase", *args, **kwargs): - if not _self.is_locked or is_dynamo_compiling(): + if not _self.is_locked or is_compiling(): return fun(_self, *args, **kwargs) cache = _self._cache if cache is None: @@ -1358,7 +1357,7 @@ def _parse_to(*args, **kwargs): num_threads = kwargs.pop("num_threads", None) other = kwargs.pop("other", None) inplace = kwargs.pop("inplace", False) - if not is_dynamo_compiling(): + if not is_compiling(): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( *args, **kwargs ) @@ -1732,7 +1731,7 @@ def _check_keys( is_leaf=_is_leaf_nontensor, ) # TODO: compile doesn't like set() over an arbitrary object - if is_dynamo_compiling(): + if is_compiling(): keys = {k for k in keys} # noqa: C416 else: keys: set[str] = set(keys) @@ -1745,7 +1744,7 @@ def _check_keys( if not strict: keys = keys.intersection(k) else: - if is_dynamo_compiling(): + if is_compiling(): k = {v for v in k} # noqa: C416 else: k = set(k) @@ -2014,7 +2013,7 @@ def _getitem_batch_size(batch_size, index): continue elif isinstance(idx, slice): batch = batch_size[count] - if is_dynamo_compiling(): + if is_compiling(): out.append(len(range(*_slice_indices(idx, batch)))) else: out.append(len(range(*idx.indices(batch)))) @@ -2446,7 +2445,7 @@ def is_non_tensor(data): def _is_non_tensor(cls: type): out = None - is_dynamo = is_dynamo_compiling() + is_dynamo = is_compiling() if not is_dynamo: 
out = _NON_TENSOR_MEMO.get(cls) if out is None: @@ -2502,7 +2501,7 @@ def new_func(self): def _unravel_key_to_tuple(key): - if not is_dynamo_compiling(): + if not is_compiling(): return _unravel_key_to_tuple_cpp(key) if isinstance(key, str): return (key,) @@ -2523,7 +2522,7 @@ def unravel_key(key): ("a", "b") """ - if not is_dynamo_compiling(): + if not is_compiling(): return unravel_key_cpp(key) if isinstance(key, str): return key @@ -2536,14 +2535,14 @@ def unravel_key(key): def unravel_keys(*keys): """Unravels a sequence of keys.""" - if not is_dynamo_compiling(): + if not is_compiling(): return unravel_keys_cpp(*keys) return tuple(unravel_key(key) for key in keys) def unravel_key_list(keys): """Unravels a list of keys.""" - if not is_dynamo_compiling(): + if not is_compiling(): return unravel_key_list_cpp(keys) return [unravel_key(key) for key in keys] @@ -2823,7 +2822,8 @@ def _is_dataclass(obj): if isinstance(obj, type) and not isinstance(obj, GenericAlias) else type(obj) ) - return hasattr(cls, _FIELDS) + # return hasattr(cls, _FIELDS) + return getattr(cls, _FIELDS, None) is not None def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: @@ -2857,3 +2857,19 @@ def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: if len(sizes): return True, (length_t, *list(sizes)[0]), dtype return True, (length_t,), dtype + + +class _ContextManager: + def __init__(self, default=None): + self._mode: Any | None = default + self._lock = threading.Lock() + + def get_mode(self) -> Any | None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + return self._mode + + def set_mode(self, type: Any | None) -> None: + cm = self._lock if not is_compiling() else nullcontext() + with cm: + self._mode = type diff --git a/test/test_nn.py b/test/test_nn.py index 3134932e9..5c32ac6e3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -16,7 +16,6 @@ from tensordict._C import unravel_key_list from tensordict.nn import ( dispatch, - probabilistic as nn_probabilistic, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, TensorDictModuleBase, @@ -31,7 +30,11 @@ ) from tensordict.nn.distributions.composite import CompositeDistribution from tensordict.nn.ensemble import EnsembleModule -from tensordict.nn.probabilistic import InteractionType, set_interaction_type +from tensordict.nn.probabilistic import ( + interaction_type, + InteractionType, + set_interaction_type, +) from tensordict.nn.utils import ( _set_dispatch_td_nn_modules, set_skip_existing, @@ -299,10 +302,8 @@ class Data: @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic_deprec(self, lazy, it, out_keys): torch.manual_seed(0) param_multiplier = 2 if lazy: @@ -332,10 +333,10 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): tensordict_module = ProbabilisticTensorDictSequential(net, prob_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (InteractionType.DETERMINISTIC, None) + if it in 
(InteractionType.DETERMINISTIC, None) else contextlib.nullcontext() ): tensordict_module(td) @@ -345,12 +346,8 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): @pytest.mark.parametrize("out_keys", [["low"], ["low1"], [("stuff", "low1")]]) @pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize("max_dist", [1.0, 2.0]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic_kwargs( - self, lazy, interaction_type, out_keys, max_dist - ): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist): torch.manual_seed(0) if lazy: net = nn.LazyLinear(4) @@ -376,10 +373,10 @@ def test_stateful_probabilistic_kwargs( tensordict_module = ProbabilisticTensorDictSequential(net, prob_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (None, InteractionType.DETERMINISTIC) + if it in (None, InteractionType.DETERMINISTIC) else contextlib.nullcontext() ): tensordict_module(td) @@ -392,13 +389,13 @@ def test_nontensor(self): in_keys=[], out_keys=["out"], ) - assert tdm(TensorDict({}))["out"] == [1, 2] + assert tdm(TensorDict())["out"] == [1, 2] tdm = TensorDictModule( lambda: "a string!", in_keys=[], out_keys=["out"], ) - assert tdm(TensorDict({}))["out"] == "a string!" + assert tdm(TensorDict())["out"] == "a string!" @pytest.mark.parametrize( "out_keys", @@ -409,10 +406,8 @@ def test_nontensor(self): ], ) @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None] - ) - def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): + @pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None]) + def test_stateful_probabilistic(self, lazy, it, out_keys): torch.manual_seed(0) param_multiplier = 2 if lazy: @@ -441,10 +436,10 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): ) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_interaction_type(interaction_type): + with set_interaction_type(it): with ( pytest.warns(UserWarning, match="deterministic_sample") - if interaction_type in (None, InteractionType.DETERMINISTIC) + if it in (None, InteractionType.DETERMINISTIC) else contextlib.nullcontext() ): tensordict_module(td) @@ -1115,18 +1110,16 @@ def test_subsequence_weight_update(self): assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight) -@pytest.mark.parametrize( - "interaction_type", [InteractionType.RANDOM, InteractionType.MODE] -) +@pytest.mark.parametrize("it", [InteractionType.RANDOM, InteractionType.MODE]) class TestSIM: - def test_cm(self, interaction_type): - with set_interaction_type(interaction_type): - assert nn_probabilistic._INTERACTION_TYPE == interaction_type + def test_cm(self, it): + with set_interaction_type(it): + assert interaction_type() == it - def test_dec(self, interaction_type): - @set_interaction_type(interaction_type) + def test_dec(self, it): + @set_interaction_type(it) def dummy(): - assert nn_probabilistic._INTERACTION_TYPE == interaction_type + assert interaction_type() == it dummy() @@ -2022,6 +2015,7 @@ class MyModule(nn.Module): params_m0 = params_m params_m0 = params_m0.apply(lambda x: 
x.data * 0) assert (params_m0 == 0).all() + assert not (params_m == 0).all() with params_m0.to_module(m): assert (params_m == 0).all() assert not (params_m == 0).all() diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f7fe9a9ff..07eac0ec1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -814,13 +814,13 @@ def test_expand_with_singleton(self, device): @set_lazy_legacy(True) def test_filling_empty_tensordict(self, device, td_type, update): if td_type == "tensordict": - td = TensorDict({}, batch_size=[16], device=device) + td = TensorDict(batch_size=[16], device=device) elif td_type == "view": - td = TensorDict({}, batch_size=[4, 4], device=device).view(-1) + td = TensorDict(batch_size=[4, 4], device=device).view(-1) elif td_type == "unsqueeze": - td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1) + td = TensorDict(batch_size=[16], device=device).unsqueeze(-1) elif td_type == "squeeze": - td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1) + td = TensorDict(batch_size=[16, 1], device=device).squeeze(-1) elif td_type == "stack": td = LazyStackedTensorDict.lazy_stack( [TensorDict({}, [], device=device) for _ in range(16)], 0 @@ -2591,7 +2591,7 @@ def test_record_stream(self): @pytest.mark.parametrize("device", get_available_devices()) def test_subtensordict_construction(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5)) + td = TensorDict(batch_size=(4, 5)) val1 = torch.randn(4, 5, 1, device=device) val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device) val1_copy = val1.clone() @@ -2694,7 +2694,7 @@ def test_tensordict_error_messages(self, device): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_indexing(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5)) + td = TensorDict(batch_size=(4, 5)) td.set("key1", torch.randn(4, 5, 1, device=device)) td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double)) @@ -2736,7 +2736,7 @@ def test_tensordict_prealloc_nested(self): N = 3 B = 5 T = 4 - buffer = TensorDict({}, batch_size=[B, N]) + buffer = TensorDict(batch_size=[B, N]) td_0 = TensorDict( { @@ -2777,7 +2777,7 @@ def test_tensordict_prealloc_nested(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_set(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5), device=device) + td = TensorDict(batch_size=(4, 5), device=device) td.set("key1", torch.randn(4, 5)) assert td.device == torch.device(device) # by default inplace: @@ -4235,7 +4235,7 @@ def test_flatten_unflatten_bis(self, td_name, device): def test_from_empty(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - new_td = TensorDict({}, batch_size=td.batch_size, device=device) + new_td = TensorDict(batch_size=td.batch_size, device=device) for key, item in td.items(): new_td.set(key, item) assert_allclose_td(td, new_td) @@ -4433,7 +4433,7 @@ def test_items_values_keys(self, td_name, device): items = list(td.items()) # Test td.items() - constructed_td1 = TensorDict({}, batch_size=td.shape) + constructed_td1 = TensorDict(batch_size=td.shape) for key, value in items: constructed_td1.set(key, value) @@ -4443,7 +4443,7 @@ def test_items_values_keys(self, td_name, device): # items = [key, value] should be verified assert len(values) == len(items) assert len(keys) == len(items) - constructed_td2 = TensorDict({}, batch_size=td.shape) + constructed_td2 = TensorDict(batch_size=td.shape) for key, value 
in list(zip(td.keys(), td.values())): constructed_td2.set(key, value) @@ -4464,7 +4464,7 @@ def test_items_values_keys(self, td_name, device): # Test td.items() # after adding the new element - constructed_td1 = TensorDict({}, batch_size=td.shape) + constructed_td1 = TensorDict(batch_size=td.shape) for key, value in items: constructed_td1.set(key, value) @@ -4476,7 +4476,7 @@ def test_items_values_keys(self, td_name, device): assert len(values) == len(items) assert len(keys) == len(items) - constructed_td2 = TensorDict({}, batch_size=td.shape) + constructed_td2 = TensorDict(batch_size=td.shape) for key, value in list(zip(td.keys(), td.values())): constructed_td2.set(key, value) @@ -9382,14 +9382,14 @@ def run_assertions(): class TestNamedDims(TestTensorDictsBase): def test_all(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tda = td.all(2) assert tda.names == ["a", "b", "d"] tda = td.any(2) assert tda.names == ["a", "b", "d"] def test_apply(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tda = td.apply(lambda x: x + 1) assert tda.names == ["a", "b", "c", "d"] tda = td.apply(lambda x: x.squeeze(2), batch_size=[3, 4, 6]) @@ -9397,15 +9397,15 @@ def test_apply(self): assert tda.names == [None] * 3 def test_cat(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) tdc = torch.cat([td, td], -1) assert tdc.names == [None] * 4 - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) tdc = torch.cat([td, td], -1) assert tdc.names == ["a", "b", "c", "d"] def test_change_batch_size(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) td.batch_size = [3, 4, 1, 6, 1] assert td.names == ["a", "b", "c", "z", None] td.batch_size = [] @@ -9417,7 +9417,7 @@ def test_change_batch_size(self): assert td.names == ["a"] def test_clone(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tdc = td.clone() assert tdc.names == ["a", "b", "c", "d"] @@ -9425,14 +9425,14 @@ def test_clone(self): assert tdc.names == ["a", "b", "c", "d"] def test_detach(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) td[""] = torch.zeros(td.shape, requires_grad=True) tdd = td.detach() assert tdd.names == ["a", "b", "c", "d"] def test_error_similar(self): with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) with pytest.raises(ValueError): td = TensorDict( {}, @@ -9446,16 +9446,16 @@ def test_error_similar(self): ) td.refine_names("a", "a", ...) 
with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) td.rename_(a="z") def test_expand(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tde = td.expand(2, 3, 4, 5, 6) assert tde.names == [None, "a", "b", "c", "d"] def test_flatten(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tdf = td.flatten(1, 3) assert tdf.names == ["a", None] tdu = tdf.unflatten(1, (4, 1, 6)) @@ -9470,11 +9470,11 @@ def test_flatten(self): assert tdu.names == [None, None, None, "d"] def test_fullname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td.names == ["a", "b", "c", "d"] def test_gather(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) idx = torch.randint(6, (3, 4, 1, 18)) tdg = td.gather(dim=-1, index=idx) assert tdg.names == ["a", "b", "c", "d"] @@ -9499,7 +9499,7 @@ def test_h5_td(self): assert td.names == list("abgd") def test_index(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td[0].names == ["b", "c", "d"] assert td[:, 0].names == ["a", "c", "d"] assert td[0, :].names == ["b", "c", "d"] @@ -9519,7 +9519,7 @@ def test_index(self): assert tdbool.ndim == 3 def test_masked_fill(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tdm = td.masked_fill(torch.zeros(3, 4, 1, dtype=torch.bool), 1.0) assert tdm.names == ["a", "b", "c", "d"] @@ -9543,16 +9543,16 @@ def test_memmap_td(self): assert td.clone().names == list("abgd") def test_nested(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td["a"] = TensorDict(batch_size=[3, 4, 1, 6]) assert td["a"].names == td.names - td["a"] = TensorDict({}, batch_size=[]) + td["a"] = TensorDict() assert td["a"].names == td.names - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=None) - td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=None) + td["a"] = TensorDict(batch_size=[3, 4, 1, 6]) td.names = ["a", "b", None, None] assert td["a"].names == td.names - td.set_("a", TensorDict({}, batch_size=[3, 4, 1, 6])) + td.set_("a", TensorDict(batch_size=[3, 4, 1, 6])) assert td["a"].names == td.names def test_nested_indexing(self): @@ -9602,15 +9602,15 @@ def test_nested_td(self): assert nested_td.contiguous()["my_nested_td"].names == list("abgd") def test_noname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) assert td.names == [None] * 4 def test_partial_name(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) assert td.names == ["a", None, None, "d"] def test_partial_set(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = 
TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", None, None, "d"] assert td.names == ["a", None, None, "d"] td.names = ["a", "b", "c", "d"] @@ -9639,7 +9639,7 @@ def test_permute_td(self): td.names = list("abcd") def test_refine_names(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6]) + td = TensorDict(batch_size=[3, 4, 5, 6]) tdr = td.refine_names(None, None, None, "d") assert tdr.names == [None, None, None, "d"] tdr = tdr.refine_names(None, None, "c", "d") @@ -9654,7 +9654,7 @@ def test_refine_names(self): assert tdr.names == ["a", None, "c", "d"] def test_rename(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", None, None, "d"] td.rename_(a="c") assert td.names == ["c", None, None, "d"] @@ -9670,7 +9670,7 @@ def test_rename(self): assert td2.names == ["w", "x", "y", "z"] def test_select(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tds = td.select() assert tds.names == ["a", "b", "c", "d"] tde = td.exclude() @@ -9707,11 +9707,11 @@ def test_split(self): # assert tdu.is_locked def test_squeeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tds = td.squeeze(0) assert tds.names == ["a", "b", "c", "d"] - td = TensorDict({}, batch_size=[3, 1, 5, 6], names=None) + td = TensorDict(batch_size=[3, 1, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tds = td.squeeze(1) assert tds.names == ["a", "c", "d"] @@ -9724,7 +9724,7 @@ def test_squeeze_td(self): td.names = list("abcd") def test_stack(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) tds = LazyStackedTensorDict.lazy_stack([td, td], 0) assert tds.names == [None, "a", "b", "c", "d"] tds = LazyStackedTensorDict.lazy_stack([td, td], -1) @@ -9762,7 +9762,7 @@ def test_sub_td(self): td.names = list("abcd") def test_subtd(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td._get_sub_tensordict(0).names == ["b", "c", "d"] assert td._get_sub_tensordict((slice(None), 0)).names == ["a", "c", "d"] assert td._get_sub_tensordict((0, slice(None))).names == ["b", "c", "d"] @@ -9826,14 +9826,14 @@ def test_to(self, device, non_blocking_pin, num_threads, inplace): assert tdt is not td def test_unbind(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) *_, tdu = td.unbind(-1) assert tdu.names == ["a", "b", "c"] *_, tdu = td.unbind(-2) assert tdu.names == ["a", "b", "d"] def test_unsqueeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tdu = td.unsqueeze(0) assert tdu.names == [None, "a", "b", "c", "d"]
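
Below is a minimal usage sketch (not part of the patch) of the behavior this refactor preserves: the module-level globals (`_INTERACTION_TYPE`, `_SKIP_EXISTING`, `DISPATCH_TDNN_MODULES`) are replaced by `_ContextManager` holders, and each context manager saves the previous mode in `__enter__` and restores it in `__exit__`. The example uses only the public `set_interaction_type` / `interaction_type` API exercised in the tests above; the locking and compile-time details stay inside `_ContextManager`.

# Hedged sketch, assuming tensordict is installed; names come from the diff above.
from tensordict.nn.probabilistic import (
    InteractionType,
    interaction_type,
    set_interaction_type,
)

# Whatever mode was active before (None by default, per the _ContextManager default).
prev = interaction_type()

# As a context manager: the mode is visible inside the block and restored on exit,
# mirroring the save/restore that __enter__/__exit__ now perform via set_mode/get_mode.
with set_interaction_type(InteractionType.RANDOM):
    assert interaction_type() == InteractionType.RANDOM
assert interaction_type() == prev

# As a decorator (set_interaction_type subclasses _DecoratorContextManager),
# the mode is only active for the duration of the call.
@set_interaction_type(InteractionType.MODE)
def sample_mode():
    return interaction_type()

assert sample_mode() == InteractionType.MODE
assert interaction_type() == prev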