diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py
index b08298dc1..227aa8106 100644
--- a/benchmarks/common/h2d_test.py
+++ b/benchmarks/common/h2d_test.py
@@ -4,26 +4,40 @@
 # LICENSE file in the root directory of this source tree.
 import argparse
+import time
+from typing import Any
 
 import pytest
 import torch
 from packaging import version
-from tensordict import TensorDict
+from tensordict import tensorclass, TensorDict
+from tensordict.utils import logger as tensordict_logger
 
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
-@pytest.fixture
-def td():
-    return TensorDict(
-        {
-            str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
-            for i in range(16)
-        },
-        batch_size=[16],
-        device="cpu",
-    )
+@tensorclass
+class NJT:
+    _values: torch.Tensor
+    _offsets: torch.Tensor
+    _lengths: torch.Tensor
+    njt_shape: Any = None
+
+    @classmethod
+    def from_njt(cls, njt_tensor):
+        return cls(
+            _values=njt_tensor._values,
+            _offsets=njt_tensor._offsets,
+            _lengths=njt_tensor._lengths,
+            njt_shape=njt_tensor.size(0),
+        )
+
+
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch.compiler.reset()
+    yield
 
 
 def _make_njt():
@@ -34,14 +48,29 @@ def _make_njt():
     )
 
 
-@pytest.fixture
-def njt_td():
+def _njt_td():
     return TensorDict(
-        {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
+        {str(i): _make_njt() for i in range(8)},
         device="cpu",
     )
 
 
+@pytest.fixture
+def njt_td():
+    return _njt_td()
+
+
+@pytest.fixture
+def td():
+    njtd = _njt_td()
+    for k0, v0 in njtd.items():
+        njtd[k0] = NJT.from_njt(v0)
+        # for k1, v1 in v0.items():
+        #     njtd[k0, k1] = NJT.from_njt(v1)
+    return njtd
+
+
 @pytest.fixture
 def default_device():
     if torch.cuda.is_available():
@@ -52,22 +81,142 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")
 
 
-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "compile_mode,num_threads",
+    [
+        [False, None],
+        # [False, 4],
+        # [False, 16],
+        ["default", None],
+        ["reduce-overhead", None],
+    ],
+)
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
+class TestConsolidate:
+    def test_consolidate(self, benchmark, td, compile_mode, num_threads):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            consolidate = torch.compile(
+                consolidate, mode=compile_mode, dynamic=True, fullgraph=True
+            )
+
+        t0 = time.time()
+        consolidate(td, num_threads=num_threads)
+        elapsed = time.time() - t0
+        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")
+
+        for _ in range(3):
+            consolidate(td, num_threads=num_threads)
+
+        benchmark(consolidate, td, num_threads)
+
+    def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+
+        def consolidate(td, num_threads):
+            return td.consolidate(num_threads=num_threads)
+
+        if compile_mode:
+            pytest.skip(
+                "Compiling NJTs consolidation currently triggers a RuntimeError."
+            )
+            # consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            consolidate(njt_td, num_threads=num_threads)
+
+        benchmark(consolidate, njt_td, num_threads)
+
+
+@pytest.mark.parametrize(
+    "consolidated,compile_mode,num_threads",
+    [
+        [False, False, None],
+        [True, False, None],
+        ["within", False, None],
+        # [True, False, 4],
+        # [True, False, 16],
+        [True, "default", None],
+    ],
+)
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
 )
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
-        if consolidated:
-            td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
+    def test_to(
+        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            td = td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
 
-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
-        if consolidated:
-            njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            to(td, num_threads=num_threads)
+
+        benchmark(to, td, num_threads)
+
+    def test_to_njt(
+        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            njt_td = njt_td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode, dynamic=True)
+
+        for _ in range(3):
+            to(njt_td, num_threads=num_threads)
+
+        benchmark(to, njt_td, num_threads)
 
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
+    pytest.main(
+        [
+            __file__,
+            "--capture",
+            "no",
+            "--exitfirst",
+            "--benchmark-group-by",
+            "func",
+            "-vvv",
+        ]
+        + unknown
+    )
diff --git a/tensordict/base.py b/tensordict/base.py
index 79ad2cfaf..0053b2634 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
             pad = 8 - pad
         else:
             pad = 0
-        flat_size.append(n + pad)
-        stop = start + flat_size[-1]
+        flat_size.append(sum([n, pad]))
+        # Using sum to tell dynamo to use sym_sum
+        stop = sum([start, flat_size[-1]])
         if requires_metadata:
             metadata_dict["leaves"][key] = (
                 _DTYPE2STRDTYPE[dtype],
@@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv):
             return v[: oldv.numel()].view(oldv.shape)
         return v.view(oldv.shape)
 
+    if num_threads is None:
+        num_threads = 0
     if num_threads > 0:
 
         def assign(
@@ -4241,7 +4244,10 @@ def _view_and_pad(tensor):
                 if v.device != storage.device:
                     v = v.to(storage.device, non_blocking=non_blocking)
                 stride = v.stride()
-                if (stride and stride[-1] != 1) or v.storage_offset():
+                if is_dynamo_compiling():
+                    if not v.is_contiguous():
+                        v = v.clone(memory_format=torch.contiguous_format)
+                elif (stride and stride[-1] != 1) or v.storage_offset():
                     v = v.clone(memory_format=torch.contiguous_format)
                 v, pad = _view_and_pad(v)
                 items.append(v)
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 37abe37c1..6f6f9cc13 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
     "_multithread_rebuild",  # rebuild checks if self is a non tensor
     "_propagate_lock",
     "_propagate_unlock",
+    "_reduce_get_metadata",
     "_values_list",
     "data_ptr",
     "dim",
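
Editor's note (not part of the diff): the new `consolidated == "within"` case times the full host-to-device path in a single call, i.e. consolidating the leaves into one (optionally pinned) buffer and then copying that buffer to the device with several threads. A minimal standalone sketch of that pattern, assuming a recent tensordict and an available CUDA device, is:

    import torch
    from tensordict import TensorDict

    td = TensorDict({str(i): torch.randn(16, 16) for i in range(8)}, device="cpu")
    # Consolidate all leaves into one contiguous, pinned storage, then move the
    # single buffer to the GPU; this is the path the "within" case benchmarks.
    td_cuda = td.consolidate(pin_memory=True).to("cuda", num_threads=4)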