diff --git a/tensordict/base.py b/tensordict/base.py index 5ffc217b3..b72ce7fbc 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10666,7 +10666,13 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ... def to(self: T, *, batch_size: torch.Size) -> T: ... def _to_cuda_with_pin_mem( - self, *, num_threads, device="cuda", non_blocking=None, to: Callable + self, + *, + num_threads, + device="cuda", + non_blocking=None, + to: Callable, + inplace: bool = False, ): if self.is_empty(): return self.to(device) @@ -10701,6 +10707,8 @@ def _to_cuda_with_pin_mem( is_leaf=_NESTED_TENSORS_AS_LISTS, propagate_lock=True, device=device, + out=self if inplace else None, + checked=True, ) return result diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index eec014faa..63cdaa181 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -69,13 +69,13 @@ T = TypeVar("T", bound="TensorDictBase") class TensorClass: def __init__( self, - # *args, + *args, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, names: Sequence[str] | None = None, non_blocking: bool | None = None, lock: bool = False, - # **kwargs, + **kwargs, ) -> None: ... @property def is_meta(self) -> bool: ...