[Performance] Faster clone
ghstack-source-id: 14d558692120ea48c40188f4eaaced9c506c0f17
Pull Request resolved: #1043
vmoens committed Oct 16, 2024
1 parent fd400af commit ee49fc7
Showing 10 changed files with 158 additions and 46 deletions.
6 changes: 6 additions & 0 deletions benchmarks/common/h2d_test.py
@@ -7,9 +7,12 @@
 
 import pytest
 import torch
+from packaging import version
 
 from tensordict import TensorDict
 
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+
 
 @pytest.fixture
 def td():
@@ -50,6 +53,9 @@ def default_device():
 
 
 @pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
 class TestTo:
     def test_to(self, benchmark, consolidated, td, default_device):
         if consolidated:
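
Note on the new version guard: parsing torch.__version__ twice, as version.parse(version.parse(torch.__version__).base_version), first strips the pre-release and local suffixes that PyTorch nightlies carry, then re-parses the bare release for comparison. A minimal sketch of the effect (illustration only, not part of the commit; the version string is hypothetical):

# Why the benchmarks parse torch.__version__ twice (illustration only).
from packaging import version

raw = "2.5.0a0+git1234"  # a typical PyTorch nightly version string (hypothetical)
base = version.parse(raw).base_version
print(base)  # "2.5.0" -- pre-release and local segments stripped
print(version.parse(base) >= version.parse("2.5.0"))  # True: the nightly passes the guard
print(version.parse(raw) >= version.parse("2.5.0"))   # False: 2.5.0a0 sorts before 2.5.0
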
35 changes: 26 additions & 9 deletions benchmarks/compile/compile_td_test.py
@@ -6,10 +6,11 @@
 
 import pytest
 import torch
+from packaging import version
 from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
 from torch.utils._pytree import tree_map
 
-TORCH_VERSION = torch.__version__
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
 @tensorclass
@@ -106,7 +107,9 @@ def get_flat_tc():
 
 
 # Tests runtime of a simple arithmetic op over a highly nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_add_one_nested(mode, dict_type, benchmark):
@@ -128,7 +131,9 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):
 
 
 # Tests the speed of copying a nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_nested(mode, dict_type, benchmark):
@@ -150,7 +155,9 @@ def test_compile_copy_nested(mode, dict_type, benchmark):
 
 
 # Tests runtime of a simple arithmetic op over a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_one_flat(mode, dict_type, benchmark):
@@ -177,7 +184,9 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
     benchmark(func, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_self_flat(mode, dict_type, benchmark):
@@ -207,7 +216,9 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):
 
 
 # Tests the speed of copying a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_flat(mode, dict_type, benchmark):
@@ -235,7 +246,9 @@ def test_compile_copy_flat(mode, dict_type, benchmark):
 
 
 # Tests the speed of assigning entries to an empty tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_assign_and_add(mode, dict_type, benchmark):
@@ -264,7 +277,9 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
 # Tests the speed of assigning entries to a lazy stacked tensordict
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.skipif(
     torch.cuda.is_available(), reason="max recursion depth error with cuda"
 )
@@ -285,7 +300,9 @@ def test_compile_assign_and_add_stack(mode, benchmark):
 
 
 # Tests indexing speed
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
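
The removed guards compared version strings lexicographically (TORCH_VERSION < "2.4" with TORCH_VERSION = torch.__version__), which misorders releases once a component reaches two digits. A two-line illustration (not part of the commit):

# String comparison vs. semantic version comparison (illustration only).
from packaging import version

print("2.10.0" < "2.4.0")                                # True -- wrong: compares "1" < "4" as characters
print(version.parse("2.10.0") < version.parse("2.4.0"))  # False -- correct semantic ordering
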
42 changes: 31 additions & 11 deletions benchmarks/compile/tensordict_nn_test.py
@@ -9,13 +9,15 @@
 
 import pytest
 import torch
 
+from packaging import version
 from tensordict import TensorDict, TensorDictParams
 
 from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
 
+sys.setrecursionlimit(10000)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+
-TORCH_VERSION = torch.__version__
-sys.setrecursionlimit(10000)
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -49,7 +51,9 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
     )
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,7 +68,9 @@ def test_mod_add(mode, benchmark):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -80,7 +86,9 @@ def test_mod_wrap(mode, benchmark):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap_and_backward(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -104,7 +112,9 @@ def module_exec(td):
     benchmark(module_exec, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -129,7 +139,9 @@ def delhidden(td):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -161,7 +173,9 @@ def delhidden(td):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap_and_backward(mode, benchmark):
@@ -201,7 +215,9 @@ def module_exec(td):
     benchmark(module_exec, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("functional", [False, True])
 def test_func_call_runtime(mode, functional, benchmark):
@@ -272,7 +288,9 @@ def call(x, td):
     benchmark(call, x)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize(
@@ -354,7 +372,9 @@ def call(x, td):
     benchmark(call_vmap, x, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("plain_decorator", [None, False, True])
3 changes: 3 additions & 0 deletions tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None:
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
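
With this change, a recursive clone dispatches to the batched _clone_recurse() (added in tensordict/base.py below) whenever the tensordict's device is known; device-less tensordicts keep the original per-value _clone_value path. A hedged usage sketch (public API only; behavior as of this commit):

# Which clone path a TensorDict takes after this change (illustration only).
import torch
from tensordict import TensorDict

td_dev = TensorDict({"a": torch.randn(3), "b": torch.zeros(3)}, batch_size=[3], device="cpu")
td_any = TensorDict({"a": torch.randn(3)}, batch_size=[3])  # device=None

clone_dev = td_dev.clone()  # device is set -> fast _clone_recurse() path
clone_any = td_any.clone()  # device is None -> original per-value path
assert clone_dev["a"].data_ptr() != td_dev["a"].data_ptr()  # a real copy either way
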
46 changes: 46 additions & 0 deletions tensordict/base.py
@@ -8123,6 +8123,52 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        foreach_vals = {}
+        iter_vals = {}
+        for key, val in zip(keys, vals):
+            if (
+                type(val) is torch.Tensor
+                and not val.requires_grad
+                and val.dtype not in (torch.bool,)
+            ):
+                foreach_vals[key] = val
+            else:
+                iter_vals[key] = val
+        if foreach_vals:
+            foreach_vals = dict(
+                _zip_strict(
+                    foreach_vals.keys(),
+                    torch._foreach_add(tuple(foreach_vals.values()), 0),
+                )
+            )
+        if iter_vals:
+            iter_vals = dict(
+                _zip_strict(
+                    iter_vals.keys(),
+                    (
+                        val.clone() if hasattr(val, "clone") else val
+                        for val in iter_vals.values()
+                    ),
+                )
+            )
+
+        items = foreach_vals
+        items.update(iter_vals)
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=False,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
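
The core of the speedup: torch._foreach_add(tensors, 0) produces fresh output tensors for a whole list of leaves in one batched call, replacing a Python loop of Tensor.clone() calls. Leaves that require grad or are boolean are routed through clone() instead, presumably because the clone should stay on the autograd tape and because arithmetic on bool tensors is restricted. A minimal sketch of the trick in isolation (torch._foreach_add is a private API, so treat this as an illustration, not a guarantee):

# Batched "clone" via a fused foreach op (illustration only).
import torch

tensors = [torch.randn(1024) for _ in range(100)]

copies = torch._foreach_add(tuple(tensors), 0)  # one call, new storage for every tensor

assert all((c == t).all() for c, t in zip(copies, tensors))                 # same values
assert all(c.data_ptr() != t.data_ptr() for c, t in zip(copies, tensors))   # new memory
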
27 changes: 16 additions & 11 deletions tensordict/nn/cudagraphs.py
@@ -267,6 +267,8 @@ def _call(
                         "The output of the function must be a tensordict, a tensorclass or None. Got "
                         f"type(out)={type(out)}."
                     )
+                if is_tensor_collection(out):
+                    out.lock_()
                 self._out = out
                 self.counter += 1
                 if self._out_matches_in:
@@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     torch._foreach_copy_(dests, srcs)
                 torch.cuda.synchronize()
                 self.graph.replay()
-                if self._return_unchanged == "clone":
-                    result = self._out.clone()
-                elif self._return_unchanged:
+                if self._return_unchanged:
                     result = self._out
                 else:
-                    result = tree_map(
-                        lambda x: x.detach().clone() if x is not None else x,
-                        self._out,
+                    result = tree_unflatten(
+                        [
+                            out.clone() if hasattr(out, "clone") else out
+                            for out in self._out
+                        ],
+                        self._out_struct,
                     )
                 return result
 
@@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     out = self.module(*self._args, **self._kwargs)
-                self._out = out
+                self._out, self._out_struct = tree_flatten(out)
                 self.counter += 1
                 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
                 # user.
@@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         f"and the identity between input and output will not match anymore. "
                         f"Make sure you don't rely on input-output identity further in the code."
                     )
-                if isinstance(self._out, torch.Tensor) or self._out is None:
-                    self._return_unchanged = (
-                        "clone" if self._out is not None else True
-                    )
+                if not self._out:
+                    self._return_unchanged = True
                 else:
+                    self._out = [
+                        out.lock_() if is_tensor_collection(out) else out
+                        for out in self._out
+                    ]
                     self._return_unchanged = False
                 return this_out
 
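
In cudagraphs.py, the recorded output is now flattened once with tree_flatten at warmup and its structure cached in _out_struct; tensordict outputs are locked so the graph-owned buffers cannot be mutated through the returned object, and each replay clones the flat leaves and rebuilds the pytree with tree_unflatten, instead of the previous tree_map deep copy. A hedged sketch of the flatten/clone/unflatten pattern (hypothetical output structure, not the module's actual code):

# Cache a pytree structure once, re-clone leaves per call (illustration only).
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

out = {"loss": torch.tensor(0.5), "stats": (torch.ones(2), None)}  # hypothetical output

leaves, struct = tree_flatten(out)  # done once, at warmup time

# Per replay: clone what supports cloning, pass through the rest (e.g. None).
result = tree_unflatten([x.clone() if hasattr(x, "clone") else x for x in leaves], struct)
assert result["stats"][1] is None
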
1 change: 1 addition & 0 deletions tensordict/utils.py
@@ -2531,6 +2531,7 @@ def _check_inbuild():
 else:
 
     def _zip_strict(*iterables):
+        iterables = tuple(tuple(it) for it in iterables)
         lengths = {len(it) for it in iterables}
         if len(lengths) > 1:
             raise ValueError("lengths of iterables differ.")
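
This one-line fix matters for _clone_recurse above, which calls _zip_strict with a dict key view and a generator: the pre-3.10 fallback checks lengths with len(), and generators do not define len(). Materializing every iterable as a tuple first makes the length check well-defined and lets zip consume the same data. A brief illustration (mirrors the fallback's logic):

# The pre-3.10 fallback, with the fix applied (illustration only).
def _zip_strict(*iterables):
    iterables = tuple(tuple(it) for it in iterables)  # materialize: generators have no len()
    lengths = {len(it) for it in iterables}
    if len(lengths) > 1:
        raise ValueError("lengths of iterables differ.")
    return zip(*iterables)

keys = {"a": 1, "b": 2}.keys()
gen = (v * 10 for v in (1, 2))  # would raise TypeError in the old fallback
print(dict(_zip_strict(keys, gen)))  # {'a': 10, 'b': 20}
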