Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 1, 2024
2 parents 839d0ba + 2292659 commit f2f1fc1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 9 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10666,7 +10666,13 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ...
def to(self: T, *, batch_size: torch.Size) -> T: ...

def _to_cuda_with_pin_mem(
self, *, num_threads, device="cuda", non_blocking=None, to: Callable
self,
*,
num_threads,
device="cuda",
non_blocking=None,
to: Callable,
inplace: bool = False,
):
if self.is_empty():
return self.to(device)
Expand Down Expand Up @@ -10701,6 +10707,8 @@ def _to_cuda_with_pin_mem(
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
device=device,
out=self if inplace else None,
checked=True,
)
return result

Expand Down
4 changes: 2 additions & 2 deletions tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ T = TypeVar("T", bound="TensorDictBase")
class TensorClass:
def __init__(
self,
# *args,
*args,
batch_size: Sequence[int] | torch.Size | int | None = None,
device: DeviceType | None = None,
names: Sequence[str] | None = None,
non_blocking: bool | None = None,
lock: bool = False,
# **kwargs,
**kwargs,
) -> None: ...
@property
def is_meta(self) -> bool: ...
Expand Down

0 comments on commit f2f1fc1

Please sign in to comment.