Skip to content

Commit

Permalink
[Performance] Faster to
Browse files Browse the repository at this point in the history
ghstack-source-id: 3dfb0b66fae82dc8cf5ef2a14eccb1bec5237ebb
Pull Request resolved: #1073
  • Loading branch information
vmoens committed Nov 5, 2024
1 parent f12d31d commit 6272510
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def _apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable = None,
is_leaf: Callable | None = None,
out: TensorDictBase | None = None,
**constructor_kwargs,
) -> T | None:
Expand All @@ -1329,7 +1329,7 @@ def _apply_nest(
"batch_size and out.batch_size must be equal when both are provided."
)
if device is not NO_DEFAULT and device != out.device:
if checked:
if not checked:
raise RuntimeError(
f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}."
)
Expand Down
16 changes: 14 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10863,7 +10863,7 @@ def to(tensor):
apply_kwargs["device"] = device if device is not None else self.device
apply_kwargs["batch_size"] = batch_size
apply_kwargs["out"] = self if inplace else None
apply_kwargs["checked"] = False
apply_kwargs["checked"] = True
if non_blocking_pin:

def to_pinmem(tensor, _to=to):
Expand All @@ -10873,7 +10873,19 @@ def to_pinmem(tensor, _to=to):
to_pinmem, propagate_lock=True, **apply_kwargs
)
else:
result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
# result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
keys, tensors = self._items_list(True, True)
tensors = [to(t) for t in tensors]
items = dict(zip(keys, tensors))
result = self._fast_apply(
lambda name, val: items.get(name, val),
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
**apply_kwargs,
)

if batch_size is not None:
result.batch_size = batch_size
if (
Expand Down

0 comments on commit 6272510

Please sign in to comment.