diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index 14ffd20f9..0e20aae75 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -7,9 +7,12 @@

 import pytest
 import torch
+from packaging import version
 from tensordict import TensorDict


+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+

 @pytest.fixture
 def td():
@@ -50,6 +53,9 @@ def default_device():


 @pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
 class TestTo:
     def test_to(self, benchmark, consolidated, td, default_device):
         if consolidated:
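The double-parse idiom introduced above, `version.parse(version.parse(torch.__version__).base_version)`, exists because nightly builds report pre-release strings such as `2.5.0a0+git1234`, which `packaging` orders *below* `2.5.0` and which would therefore wrongly skip the gated tests on dev builds. A minimal standalone sketch (the version string here is hypothetical):

```python
from packaging import version

nightly = version.parse("2.5.0a0+git1234")  # hypothetical dev/nightly build

# A raw parse keeps the pre-release marker, which sorts before the release:
print(nightly < version.parse("2.5.0"))  # True -> gated tests would be skipped

# base_version strips pre-release and local segments ("2.5.0"), so the
# re-parsed version compares as the release it is based on:
print(version.parse(nightly.base_version) < version.parse("2.5.0"))  # False
```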
diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py
index eb98fac8d..3a1ef0ee1 100644
--- a/benchmarks/compile/compile_td_test.py
+++ b/benchmarks/compile/compile_td_test.py
@@ -6,10 +6,11 @@

 import pytest
 import torch
+from packaging import version
 from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
 from torch.utils._pytree import tree_map

-TORCH_VERSION = torch.__version__
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


 @tensorclass
@@ -106,7 +107,9 @@ def get_flat_tc():


 # Tests runtime of a simple arithmetic op over a highly nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_add_one_nested(mode, dict_type, benchmark):
@@ -128,7 +131,9 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):


 # Tests the speed of copying a nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_nested(mode, dict_type, benchmark):
@@ -150,7 +155,9 @@ def test_compile_copy_nested(mode, dict_type, benchmark):


 # Tests runtime of a simple arithmetic op over a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_one_flat(mode, dict_type, benchmark):
@@ -177,7 +184,9 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
     benchmark(func, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_self_flat(mode, dict_type, benchmark):
@@ -207,7 +216,9 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):


 # Tests the speed of copying a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_flat(mode, dict_type, benchmark):
@@ -235,7 +246,9 @@ def test_compile_copy_flat(mode, dict_type, benchmark):


 # Tests the speed of assigning entries to an empty tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_assign_and_add(mode, dict_type, benchmark):
@@ -264,7 +277,9 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):


 # Tests the speed of assigning entries to a lazy stacked tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.skipif(
     torch.cuda.is_available(), reason="max recursion depth error with cuda"
 )
@@ -285,7 +300,9 @@ def test_compile_assign_and_add_stack(mode, benchmark):


 # Tests indexing speed
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
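Every hunk in this file swaps a raw string comparison (`TORCH_VERSION < "2.4"`) for a `Version` comparison. Strings compare lexicographically, so the old guards would misfire as soon as a version component gains a digit; a standalone illustration:

```python
from packaging import version

# Lexicographic string order puts "2.10" before "2.4" -> wrong skip decision:
print("2.10.0" < "2.4")  # True

# Version objects compare component-wise, as intended:
print(version.parse("2.10.0") < version.parse("2.4.0"))  # False
```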
diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py
index 94110162a..7828c29f6 100644
--- a/benchmarks/compile/tensordict_nn_test.py
+++ b/benchmarks/compile/tensordict_nn_test.py
@@ -9,13 +9,15 @@

 import pytest
 import torch
+
+from packaging import version
 from tensordict import TensorDict, TensorDictParams

 from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

-sys.setrecursionlimit(10000)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

-TORCH_VERSION = torch.__version__
+sys.setrecursionlimit(10000)

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -49,7 +51,9 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
     )


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,7 +68,9 @@ def test_mod_add(mode, benchmark):
     benchmark(module, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -80,7 +86,9 @@ def test_mod_wrap(mode, benchmark):
     benchmark(module, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap_and_backward(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -104,7 +112,9 @@ def module_exec(td):
     benchmark(module_exec, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -129,7 +139,9 @@ def delhidden(td):
     benchmark(module, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -161,7 +173,9 @@ def delhidden(td):
     benchmark(module, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap_and_backward(mode, benchmark):
@@ -201,7 +215,9 @@ def module_exec(td):
     benchmark(module_exec, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("functional", [False, True])
 def test_func_call_runtime(mode, functional, benchmark):
@@ -272,7 +288,9 @@ def call(x, td):
     benchmark(call, x)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize(
@@ -354,7 +372,9 @@ def call(x, td):
     benchmark(call_vmap, x, td)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("plain_decorator", [None, False, True])
diff --git a/tensordict/_td.py b/tensordict/_td.py
index f56e25052..4387839b5 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])

     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None:
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
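The `_clone` fast path above fires only when the TensorDict carries an explicit device: the batched clone in `_clone_recurse` groups leaves into `torch._foreach_*` calls, which operate on lists of same-device tensors, and a non-None `device` is what guarantees that property. A standalone sketch of the gating condition, using only the public `TensorDict` constructor:

```python
import torch
from tensordict import TensorDict

# Without an explicit device, td.device is None and the per-leaf
# clone path is kept:
td_any = TensorDict({"a": torch.ones(3)}, batch_size=[3])
print(td_any.device)  # None

# With a device, every leaf is guaranteed to live on it, so the
# batched _clone_recurse() path is safe to take:
td_cpu = TensorDict({"a": torch.ones(3)}, batch_size=[3], device="cpu")
print(td_cpu.device)  # cpu
print(td_cpu.clone()["a"].data_ptr() != td_cpu["a"].data_ptr())  # True
```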
hasattr(val, "clone") else val + for val in iter_vals.values() + ), + ) + ) + + items = foreach_vals + items.update(iter_vals) + result = self._fast_apply( + lambda name, val: items.pop(name, None), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=False, + filter_empty=True, + default=None, + ) + if items: + result.update(items) + return result + def add( self, other: TensorDictBase | torch.Tensor, diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index e99236b48..c3fcffe5f 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -267,6 +267,8 @@ def _call( "The output of the function must be a tensordict, a tensorclass or None. Got " f"type(out)={type(out)}." ) + if is_tensor_collection(out): + out.lock_() self._out = out self.counter += 1 if self._out_matches_in: @@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): torch._foreach_copy_(dests, srcs) torch.cuda.synchronize() self.graph.replay() - if self._return_unchanged == "clone": - result = self._out.clone() - elif self._return_unchanged: + if self._return_unchanged: result = self._out else: - result = tree_map( - lambda x: x.detach().clone() if x is not None else x, - self._out, + result = tree_unflatten( + [ + out.clone() if hasattr(out, "clone") else out + for out in self._out + ], + self._out_struct, ) return result @@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph): out = self.module(*self._args, **self._kwargs) - self._out = out + self._out, self._out_struct = tree_flatten(out) self.counter += 1 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn # user. @@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor): f"and the identity between input and output will not match anymore. " f"Make sure you don't rely on input-output identity further in the code." 
diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index e99236b48..c3fcffe5f 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -267,6 +267,8 @@ def _call(
                     "The output of the function must be a tensordict, a tensorclass or None. Got "
                     f"type(out)={type(out)}."
                 )
+            if is_tensor_collection(out):
+                out.lock_()
             self._out = out
             self.counter += 1
             if self._out_matches_in:
@@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     torch._foreach_copy_(dests, srcs)
                 torch.cuda.synchronize()
                 self.graph.replay()
-                if self._return_unchanged == "clone":
-                    result = self._out.clone()
-                elif self._return_unchanged:
+                if self._return_unchanged:
                     result = self._out
                 else:
-                    result = tree_map(
-                        lambda x: x.detach().clone() if x is not None else x,
-                        self._out,
+                    result = tree_unflatten(
+                        [
+                            out.clone() if hasattr(out, "clone") else out
+                            for out in self._out
+                        ],
+                        self._out_struct,
                     )
                 return result

@@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     out = self.module(*self._args, **self._kwargs)
-                self._out = out
+                self._out, self._out_struct = tree_flatten(out)
                 self.counter += 1
                 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
                 # user.
@@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         f"and the identity between input and output will not match anymore. "
                         f"Make sure you don't rely on input-output identity further in the code."
                     )
-                if isinstance(self._out, torch.Tensor) or self._out is None:
-                    self._return_unchanged = (
-                        "clone" if self._out is not None else True
-                    )
+                if not self._out:
+                    self._return_unchanged = True
                 else:
+                    self._out = [
+                        out.lock_() if is_tensor_collection(out) else out
+                        for out in self._out
+                    ]
                     self._return_unchanged = False
                 return this_out
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 7d5c0a624..280b224a0 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -2531,6 +2531,7 @@ def _check_inbuild():
 else:

     def _zip_strict(*iterables):
+        iterables = tuple(tuple(it) for it in iterables)
         lengths = {len(it) for it in iterables}
         if len(lengths) > 1:
             raise ValueError("lengths of iterables differ.")
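The CUDAGraph wrapper now stores the module output as a flat leaf list plus a pytree spec (`self._out, self._out_struct = tree_flatten(out)`), so replay can clone leaf by leaf and rebuild the original container with `tree_unflatten`, whatever its nesting. A standalone round-trip sketch with a hypothetical output structure:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

out = {"loss": torch.tensor(1.0), "stats": (torch.zeros(2), "done")}  # hypothetical

# Flatten once, remembering the container structure ...
leaves, spec = tree_flatten(out)

# ... then clone the tensor leaves and rebuild the same structure.
cloned = tree_unflatten(
    [leaf.clone() if hasattr(leaf, "clone") else leaf for leaf in leaves],
    spec,
)
print(cloned)  # same dict/tuple shape, with fresh tensors
```

The small `_zip_strict` change in tensordict/utils.py supports this style of code: materializing the iterables up front lets the length check work even when a generator (like the `val.clone()` comprehension in `_clone_recurse`) is passed in.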
diff --git a/test/test_compile.py b/test/test_compile.py
index ff4e79f38..66ad901e4 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -33,11 +33,11 @@
 from torch.utils._pytree import SUPPORTED_NODES, tree_map

-TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

 _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None

-_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")
+_v2_5 = TORCH_VERSION >= version.parse("2.5.0")


 def test_vmap_compile():
@@ -53,7 +53,9 @@ def func(x, y):
     funcv_c(x, y)


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTD:
     def test_tensor_output(self, mode):
@@ -340,7 +342,9 @@ class MyClass:
     c: Any = None


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTC:
     def test_tc_tensor_output(self, mode):
@@ -579,7 +583,9 @@ def locked_op(tc):
     assert (tc_op == tc_op_c).all()


-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestNN:
     def test_func(self, mode):
@@ -657,7 +663,9 @@ def test_dispatch_tensor(self, mode):
     torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))


-@pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestFunctional:
     def test_functional_error(self, mode):
@@ -695,7 +703,9 @@ def call(x, td):

     # in-place modif raises an error even if fullgraph=False
     @pytest.mark.parametrize("modif_param", [False])
-    @pytest.mark.skipif(not (TORCH_VERSION > "2.5.0"), reason="requires torch>2.5")
+    @pytest.mark.skipif(
+        TORCH_VERSION <= version.parse("2.5.0"), reason="requires torch>2.5"
+    )
     def test_functional(self, modif_param, mode):

         # TODO: UNTESTED
@@ -757,7 +767,9 @@ def call(x, td):
         assert (td_zero == 0).all()

     # in-place modif raises an error even if fullgraph=False
-    @pytest.mark.skipif(not (TORCH_VERSION > "2.5.0"), reason="requires torch>2.5")
+    @pytest.mark.skipif(
+        TORCH_VERSION <= version.parse("2.5.0"), reason="requires torch>2.5"
+    )
     def test_vmap_functional(self, mode):
         module = torch.nn.Sequential(
             torch.nn.Linear(3, 4),
@@ -883,7 +895,9 @@ def to_numpy(tensor):
     )


-@pytest.mark.skipif(TORCH_VERSION <= "2.4.1", reason="requires torch>=2.5")
+@pytest.mark.skipif(
+    TORCH_VERSION <= version.parse("2.4.1"), reason="requires torch>=2.5"
+)
 @pytest.mark.parametrize("compiled", [False, True])
 class TestCudaGraphs:
     @pytest.fixture(scope="class", autouse=True)
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 64e8a5616..2a30b1593 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -108,9 +108,7 @@ def test_fsdp_module(self, tmpdir):


 # not using TorchVersion to make the comparison work with dev
-TORCH_VERSION = version.parse(
-    ".".join(map(str, version.parse(torch.__version__).release))
-)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


 @pytest.mark.skipif(
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 099d94b25..e67b48416 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -92,11 +92,11 @@
     _has_h5py = True
 except ImportError:
     _has_h5py = False
-TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

 _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None

-_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")
+_v2_5 = TORCH_VERSION >= version.parse("2.5.0")

 _IS_OSX = platform.system() == "Darwin"
 _IS_WINDOWS = sys.platform == "win32"
@@ -7913,7 +7913,9 @@ def check_id(a, b):
     torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
     assert tdload.is_consolidated()

-    @pytest.mark.skipif(not _v2_5, reason="v2.5 required for this test")
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.6.0"), reason="v2.6 required for this test"
+    )
     @pytest.mark.parametrize("device", [None, *get_available_devices()])
     @pytest.mark.parametrize("use_file", [False, True])
     @pytest.mark.parametrize("num_threads", [0, 1, 4])
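All three test files converge on the same `TORCH_VERSION` recipe. The spelling removed from test_distributed.py (joining `Version.release`) and the `base_version` spelling adopted here should be equivalent for dev builds; a standalone check with a hypothetical nightly string:

```python
from packaging import version

raw = version.parse("2.5.0a0+git1234")  # hypothetical nightly build

# Old spelling in test_distributed.py: join the numeric release tuple.
old = version.parse(".".join(map(str, raw.release)))
# New spelling used across the suite after this change.
new = version.parse(raw.base_version)

print(old == new, new)  # True 2.5.0
```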