[Performance] Faster `to`
ghstack-source-id: 3dfb0b66fae82dc8cf5ef2a14eccb1bec5237ebb
Pull Request resolved: #1073

(cherry picked from commit 6272510)
vmoens committed Nov 14, 2024
1 parent 80dedb3 commit f031bf2
Showing 2 changed files with 28 additions and 5 deletions.
17 changes: 13 additions & 4 deletions tensordict/_td.py
@@ -1302,7 +1302,7 @@ def _apply_nest(
         nested_keys: bool = False,
         prefix: tuple = (),
         filter_empty: bool | None = None,
-        is_leaf: Callable = None,
+        is_leaf: Callable | None = None,
         out: TensorDictBase | None = None,
         **constructor_kwargs,
     ) -> T | None:
@@ -1319,9 +1319,18 @@ def _apply_nest(
"batch_size and out.batch_size must be equal when both are provided."
)
if device is not NO_DEFAULT and device != out.device:
raise RuntimeError(
"device and out.device must be equal when both are provided."
)
if not checked:
raise RuntimeError(
f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}."
)
else:
device = torch.device(device)
out._device = device
for node in out.values(True, True, is_leaf=_is_tensor_collection):
if is_tensorclass(node):
node._tensordict._device = device
else:
node._device = device
else:

def make_result(names=names, batch_size=batch_size):
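Context for the hunk above: with `checked=True` (set by the faster `to` path in base.py below), a device mismatch between `device` and a caller-supplied `out` is no longer treated as a user error; instead the target device is written onto `out` and every nested tensor collection so the whole tree reports the new device. A minimal sketch of that invariant, using only the public TensorDict API (the keys "a"/"sub" are illustrative; `_device` and `_tensordict` above are internals):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(2), "sub": {"b": torch.ones(2)}},
    batch_size=[2],
    device="cpu",
)

target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = td.to(target)

# The root and every nested node report the target device after the move,
# which is what the loop over `out.values(...)` above enforces on the fast path.
assert moved.device.type == target.type
assert moved["sub"].device.type == target.type  # nested nodes carry the device too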
16 changes: 15 additions & 1 deletion tensordict/base.py
@@ -10586,6 +10586,8 @@ def to(tensor):
         else:
             apply_kwargs["device"] = device if device is not None else self.device
             apply_kwargs["batch_size"] = batch_size
+            apply_kwargs["out"] = self if inplace else None
+            apply_kwargs["checked"] = True
         if non_blocking_pin:
 
             def to_pinmem(tensor, _to=to):
@@ -10595,7 +10597,19 @@ def to_pinmem(tensor, _to=to):
                 to_pinmem, propagate_lock=True, **apply_kwargs
             )
         else:
-            result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
+            # result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
+            keys, tensors = self._items_list(True, True)
+            tensors = [to(t) for t in tensors]
+            items = dict(zip(keys, tensors))
+            result = self._fast_apply(
+                lambda name, val: items.get(name, val),
+                named=True,
+                nested_keys=True,
+                is_leaf=_NESTED_TENSORS_AS_LISTS,
+                propagate_lock=True,
+                **apply_kwargs,
+            )
 
         if batch_size is not None:
             result.batch_size = batch_size
         if (
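Context for the hunk above: instead of letting `_fast_apply` call `to` on each leaf while it rebuilds the tree, the new path collects all leaf tensors once into a flat list, converts that list in one comprehension, and writes the results back under their original keys in a single pass; the `out=self` / `checked=True` kwargs from the first hunk let that pass reuse the existing structure when the move is in place. A rough sketch of the same flatten-convert-rebuild pattern written against the public API rather than the internal `_items_list`/`_fast_apply` helpers (the keys "obs"/"reward" and the helper name `to_flat` are illustrative):

import torch
from tensordict import TensorDict


def to_flat(td: TensorDict, device: torch.device) -> TensorDict:
    # 1) Flatten: one pass collects every (possibly nested) key and its leaf tensor.
    keys, tensors = zip(*td.items(include_nested=True, leaves_only=True))
    # 2) Convert: move the flat list of tensors; no tree traversal happens here.
    converted = [t.to(device, non_blocking=True) for t in tensors]
    # 3) Rebuild: write the converted leaves back under their original keys.
    out = TensorDict({}, batch_size=td.batch_size, device=device)
    for key, value in zip(keys, converted):
        out.set(key, value)
    return out


td = TensorDict(
    {"obs": torch.randn(4, 3), "nested": {"reward": torch.zeros(4, 1)}},
    batch_size=[4],
)
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = to_flat(td, target)
assert moved["nested", "reward"].device.type == target.type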
