Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 1, 2024
1 parent affa581 commit 71b03ed
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10666,7 +10666,13 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ...
def to(self: T, *, batch_size: torch.Size) -> T: ...

def _to_cuda_with_pin_mem(
self, *, num_threads, device="cuda", non_blocking=None, to: Callable
self,
*,
num_threads,
device="cuda",
non_blocking=None,
to: Callable,
inplace: bool = False,
):
if self.is_empty():
return self.to(device)
Expand Down Expand Up @@ -10701,6 +10707,8 @@ def _to_cuda_with_pin_mem(
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
device=device,
out=self if inplace else None,
checked=True,
)
return result

Expand Down

0 comments on commit 71b03ed

Please sign in to comment.