[Feature] inplace to method
ghstack-source-id: eb25717dd0c9d4581b0ba19aff241e968f8face0
Pull Request resolved: #1066
vmoens committed Nov 1, 2024
1 parent b06de95 commit ab2ad20
Showing 9 changed files with 155 additions and 27 deletions.
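This commit adds an `inplace` keyword argument to `to()`: `_parse_to` returns an extra `inplace` flag, the base `TensorDictBase.to()` hands `out=self` to the apply machinery when the flag is set, and subclasses that cannot write in place (the lazy tensordict whose `to()` is patched in `_lazy.py`, `_SubTensorDict`, `PersistentTensorDict`) raise a `TypeError`. A minimal usage sketch, mirroring the new `test_to` test at the bottom of this diff (keys and shapes are illustrative, not taken from the library):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3)}, batch_size=[3])  # illustrative contents

# Default (out-of-place): a new tensordict is returned.
td_double = td.to(torch.float64)
assert td_double is not td

# inplace=True: the same tensordict is returned, its leaves cast in place.
td_inplace = td.to(torch.float64, inplace=True)
assert td_inplace is td
assert td["a"].dtype == torch.float64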
14 changes: 9 additions & 5 deletions tensordict/_lazy.py
@@ -2279,7 +2279,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
@@ -3253,7 +3253,7 @@ def _has_names(self):

def _erase_names(self):
raise RuntimeError(
f"Cannot erase names of a {type(self)}. "
f"Cannot erase names of a {type(self).__name__}. "
f"Erase source TensorDict's names instead."
)

@@ -3379,7 +3379,7 @@ def _stack_onto_(
dim: int,
) -> T:
raise RuntimeError(
f"stacking tensordicts is not allowed for type {type(self)}"
f"stacking tensordicts is not allowed for type {type(self).__name__}"
f"consider calling 'to_tensordict()` first"
)

@@ -3480,9 +3480,13 @@ def to(self, *args, **kwargs) -> T:
batch_size,
pin_memory,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
if inplace:
raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")

if batch_size is not None:
raise TypeError(f"Cannot pass batch-size to a {type(self)}.")
raise TypeError(f"Cannot pass batch-size to {type(self).__name__}.to().")
result = self

if device is not None and dtype is None and device == self.device:
@@ -3757,7 +3761,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
27 changes: 20 additions & 7 deletions tensordict/_td.py
@@ -1152,7 +1152,7 @@ def _multithread_rebuild(
) -> None:
if constructor_kwargs:
raise RuntimeError(
f"constructor_kwargs not supported for class {type(self)}."
f"constructor_kwargs not supported for class {type(self).__name__}."
)
# Rebuilds a tensordict from the futures of its leaves
if inplace:
@@ -1201,7 +1201,7 @@ def setter(
return
result.set(key, item_trsf, inplace=inplace)

elif isinstance(result, TensorDict) and checked and (inplace is not True):
elif checked and isinstance(result, TensorDict) and (inplace is not True):

def setter(
item_trsf,
@@ -1329,9 +1329,18 @@ def _apply_nest(
"batch_size and out.batch_size must be equal when both are provided."
)
if device is not NO_DEFAULT and device != out.device:
raise RuntimeError(
"device and out.device must be equal when both are provided."
)
if checked:
raise RuntimeError(
f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}."
)
else:
device = torch.device(device)
out._device = device
for node in out.values(True, True, is_leaf=_is_tensor_collection):
if is_tensorclass(node):
node._tensordict._device = device
else:
node._device = device
else:

def make_result(names=names, batch_size=batch_size):
@@ -3594,9 +3603,13 @@ def to(self, *args, **kwargs: Any) -> T:
batch_size,
pin_memory,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
result = self

if inplace:
raise TypeError(
"Cannot send a _SubTensorDict instance to device/dtype inplace."
)
if device is not None and dtype is None and device == self.device:
return result
return self.to_tensordict().to(*args, **kwargs)
@@ -4093,7 +4106,7 @@ def _cast_reduction(
except Exception:
raise RuntimeError(
f"{reduction_name} requires this object to be cast to a regular TensorDict. "
f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
f" on github!"
)
return td._cast_reduction(
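The `_apply_nest` change above is what makes `out=self` work with a new device: on the unchecked path, instead of raising, the device metadata is swapped on `out` and propagated to every nested tensordict/tensorclass node. A hedged sketch of the resulting behaviour (keys and shapes are illustrative; requires a CUDA device):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3), "nested": TensorDict({"b": torch.ones(3)}, batch_size=[3])},
    batch_size=[3],
    device="cpu",
)  # illustrative contents
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    out = td.to(device, inplace=True)         # routes through _apply_nest with out=self
    assert out is td                          # the very same object is returned
    assert td.device == device                # root device metadata was updated...
    assert td.get("nested").device == device  # ...and propagated to the nested node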
35 changes: 31 additions & 4 deletions tensordict/base.py
@@ -10650,6 +10650,7 @@ def to(
device: Optional[Union[int, device]] = ...,
dtype: Optional[Union[torch.device, str]] = ...,
non_blocking: bool = ...,
inplace: bool = False,
) -> T: ...

@overload
@@ -10665,7 +10666,13 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ...
def to(self: T, *, batch_size: torch.Size) -> T: ...

def _to_cuda_with_pin_mem(
self, *, num_threads, device="cuda", non_blocking=None, to: Callable
self,
*,
num_threads,
device="cuda",
non_blocking=None,
to: Callable,
inplace: bool = False,
):
if self.is_empty():
return self.to(device)
@@ -10700,6 +10707,8 @@ def _to_cuda_with_pin_mem(
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
device=device,
out=self if inplace else None,
checked=True,
)
return result

@@ -10751,6 +10760,9 @@ def to(self, *args, **kwargs) -> T:
``max(1, torch.get_num_threads())`` threads will be spawned.
``num_threads=0`` will cancel any
multithreading for the `pin_memory()` calls.
inplace (bool, optional): if ``True``, the data will be written in-place in the same tensordict.
This can be significantly faster whenever building a tensordict is CPU-overhead bound.
Defaults to ``False``.
Returns:
a new tensordict instance if the device differs from the tensordict
@@ -10779,6 +10791,7 @@ def to(self, *args, **kwargs) -> T:
batch_size,
non_blocking_pin,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
result = self

@@ -10791,6 +10804,7 @@ def to(self, *args, **kwargs) -> T:
pin_memory=non_blocking_pin,
num_threads=num_threads,
non_blocking=non_blocking,
inplace=inplace,
)

if non_blocking is None:
@@ -10822,11 +10836,13 @@ def to(tensor):
if num_threads is None:
num_threads = max(1, torch.get_num_threads() // 2)
result = self._to_cuda_with_pin_mem(
num_threads=num_threads, to=to, device=device
num_threads=num_threads, to=to, device=device, inplace=inplace
)
else:
apply_kwargs["device"] = device if device is not None else self.device
apply_kwargs["batch_size"] = batch_size
apply_kwargs["out"] = self if inplace else None
apply_kwargs["checked"] = False
if non_blocking_pin:

def to_pinmem(tensor, _to=to):
@@ -10848,7 +10864,9 @@ def to_pinmem(tensor, _to=to):
self._sync_all()
return result

def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
def _to_consolidated(
self, *, device, pin_memory, num_threads, non_blocking, inplace
):
if num_threads is None:
# unspecified num_threads should mean 0
num_threads = 0
@@ -10911,8 +10929,17 @@ def set_(x):
storage_offset=storage_offset,
)

if inplace:
out = self
else:
out = None

result = self._fast_apply(
set_, device=torch.device(device), num_threads=num_threads
set_,
device=torch.device(device),
num_threads=num_threads,
out=out,
checked=True,
)
result._consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
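`_to_cuda_with_pin_mem` and `_to_consolidated` forward the same flag by handing `out=self` (with `checked=True`) to the apply machinery, so the fast transfer paths can also rewrite the current object instead of allocating a new one. A hedged sketch of the consolidated path (keys and shapes are illustrative; assumes CUDA is available and that `consolidate()` was called beforehand):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(128, 4), "reward": torch.randn(128, 1)},
    batch_size=[128],
).consolidate()          # pack every leaf into one contiguous storage (illustrative contents)

if torch.cuda.is_available():
    out = td.to("cuda:0", num_threads=4, inplace=True)
    assert out is td     # leaves were rewritten in place as views on the cast storage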
8 changes: 6 additions & 2 deletions tensordict/persistent.py
@@ -476,7 +476,7 @@ def keys(
) -> _PersistentTDKeysView:
if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor):
raise ValueError(
f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self)}."
f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self).__name__}."
)
return _PersistentTDKeysView(
tensordict=self,
@@ -1026,7 +1026,7 @@ def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
batch_size,
non_blocking_pin,
num_threads,
inplace,
) = _parse_to(*args, **kwargs)
if inplace:
raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")

if non_blocking_pin:
raise RuntimeError(
f"Cannot use non_blocking_pin=True {type(self).__name__}.to(). Call "
Expand Down Expand Up @@ -1181,7 +1185,7 @@ def _convert_inplace(self, inplace, key):

def _set_non_tensor(self, key: NestedKey, value: Any):
raise NotImplementedError(
f"set_non_tensor is not compatible with the tensordict type {type(self)}."
f"set_non_tensor is not compatible with the tensordict type {type(self).__name__}."
)

def _set_str(
4 changes: 2 additions & 2 deletions tensordict/tensorclass.py
@@ -1707,7 +1707,7 @@ def _set_str(
):
if is_non_tensor(self):
if key != "data":
raise KeyError(f"only 'data' keys are supported for {type(self)}.")
raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
while isinstance(value, (NonTensorData, NonTensorStack)):
value = value.data
self._non_tensordict[key] = value
@@ -1737,7 +1737,7 @@ def _set_at_str(
):
if is_non_tensor(self):
if key != "data":
raise KeyError(f"only 'data' keys are supported for {type(self)}.")
raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
while isinstance(value, (NonTensorData, NonTensorStack)):
value = value.data
self._non_tensordict[key] = value
2 changes: 2 additions & 0 deletions tensordict/utils.py
@@ -1356,6 +1356,7 @@ def _parse_to(*args, **kwargs):
non_blocking_pin = kwargs.pop("non_blocking_pin", False)
num_threads = kwargs.pop("num_threads", None)
other = kwargs.pop("other", None)
inplace = kwargs.pop("inplace", False)
if not is_dynamo_compiling():
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
@@ -1397,6 +1398,7 @@
batch_size,
non_blocking_pin,
num_threads,
inplace,
)


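Because `_parse_to` now returns one extra element, every `to()` override touched in this diff unpacks an eighth value; that is why so many call sites changed. A small check of the new return signature (`_parse_to` is an internal helper and may change; it is imported here only for illustration):

from tensordict.utils import _parse_to  # internal helper; shown for illustration only

(
    device,
    dtype,
    non_blocking,
    convert_to_format,
    batch_size,
    non_blocking_pin,
    num_threads,
    inplace,
) = _parse_to("cpu", inplace=True)
assert inplace is True
assert str(device) == "cpu"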
5 changes: 5 additions & 0 deletions test/_utils_internal.py
@@ -34,6 +34,11 @@ def get_available_devices():
devices += [torch.device(f"cuda:{i}")]
if i == 1:
break
# if torch.backends.mps.is_available():
# for i in range(torch.mps.device_count()):
# devices += [torch.device(f"mps:{i}")]
# if i == 1:
# break
return devices


22 changes: 20 additions & 2 deletions test/test_tensorclass.py
@@ -2056,8 +2056,26 @@ def test_split(self):

def test_to(self):
td = self.get_nested()
td = td.to("cpu:1")
assert isinstance(td.get("c")[0], self.TensorClass)
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu:1")
td_device = td.to(device)
assert isinstance(td_device.get("c")[0], self.TensorClass)
assert td_device is not td
assert td_device.device == device

td_device = td.to(device, inplace=True)
assert td_device is td
assert td_device.device == device

td_cpu = td_device.to("cpu", inplace=True)
assert td_cpu.device == torch.device("cpu")

td_double = td.to(torch.float64, inplace=True)
assert td_double is td
assert td_double.dtype == torch.double
assert td_double.device == torch.device("cpu")


def test_decorator():