diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index c093079d2..7149c196d 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -2279,7 +2279,7 @@ def _cast_reduction(
         except Exception:
             raise RuntimeError(
                 f"{reduction_name} requires this object to be cast to a regular TensorDict. "
-                f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
+                f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
                 f" on github!"
             )
         return td._cast_reduction(
@@ -3253,7 +3253,7 @@ def _has_names(self):

     def _erase_names(self):
         raise RuntimeError(
-            f"Cannot erase names of a {type(self)}. "
+            f"Cannot erase names of a {type(self).__name__}. "
             f"Erase source TensorDict's names instead."
         )

@@ -3379,7 +3379,7 @@ def _stack_onto_(
         dim: int,
     ) -> T:
         raise RuntimeError(
-            f"stacking tensordicts is not allowed for type {type(self)}"
+            f"stacking tensordicts is not allowed for type {type(self).__name__}; "
             f"consider calling 'to_tensordict()` first"
         )

@@ -3480,9 +3480,13 @@ def to(self, *args, **kwargs) -> T:
             batch_size,
             pin_memory,
             num_threads,
+            inplace,
         ) = _parse_to(*args, **kwargs)
+        if inplace:
+            raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")
+
         if batch_size is not None:
-            raise TypeError(f"Cannot pass batch-size to a {type(self)}.")
+            raise TypeError(f"Cannot pass batch-size to {type(self).__name__}.to().")
         result = self

         if device is not None and dtype is None and device == self.device:
@@ -3757,7 +3761,7 @@ def _cast_reduction(
         except Exception:
             raise RuntimeError(
                 f"{reduction_name} requires this object to be cast to a regular TensorDict. "
-                f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
+                f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
                 f" on github!"
             )
         return td._cast_reduction(
diff --git a/tensordict/_td.py b/tensordict/_td.py
index fd95707fa..9b084fe3b 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -1152,7 +1152,7 @@ def _multithread_rebuild(
     ) -> None:
         if constructor_kwargs:
             raise RuntimeError(
-                f"constructor_kwargs not supported for class {type(self)}."
+                f"constructor_kwargs not supported for class {type(self).__name__}."
             )
         # Rebuilds a tensordict from the futures of its leaves
         if inplace:
@@ -1201,7 +1201,7 @@ def setter(
                 return result.set(key, item_trsf, inplace=inplace)

-        elif isinstance(result, TensorDict) and checked and (inplace is not True):
+        elif checked and isinstance(result, TensorDict) and (inplace is not True):

             def setter(
                 item_trsf,
@@ -1329,9 +1329,18 @@ def _apply_nest(
                     "batch_size and out.batch_size must be equal when both are provided."
                 )
             if device is not NO_DEFAULT and device != out.device:
-                raise RuntimeError(
-                    "device and out.device must be equal when both are provided."
-                )
+                if checked:
+                    raise RuntimeError(
+                        f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}."
+                    )
+                else:
+                    device = torch.device(device)
+                    out._device = device
+                    for node in out.values(True, True, is_leaf=_is_tensor_collection):
+                        if is_tensorclass(node):
+                            node._tensordict._device = device
+                        else:
+                            node._device = device
         else:

             def make_result(names=names, batch_size=batch_size):
@@ -3594,9 +3603,13 @@ def to(self, *args, **kwargs: Any) -> T:
             batch_size,
             pin_memory,
             num_threads,
+            inplace,
         ) = _parse_to(*args, **kwargs)
         result = self
-
+        if inplace:
+            raise TypeError(
+                "Cannot send a _SubTensorDict instance to device/dtype inplace."
+            )
         if device is not None and dtype is None and device == self.device:
             return result
         return self.to_tensordict().to(*args, **kwargs)
@@ -4093,7 +4106,7 @@ def _cast_reduction(
         except Exception:
             raise RuntimeError(
                 f"{reduction_name} requires this object to be cast to a regular TensorDict. "
-                f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
+                f"If you need {type(self).__name__} to support {reduction_name}, help us by filing an issue"
                 f" on github!"
             )
         return td._cast_reduction(
diff --git a/tensordict/base.py b/tensordict/base.py
index 79ad2cfaf..b72ce7fbc 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -10650,6 +10650,7 @@ def to(
         device: Optional[Union[int, device]] = ...,
         dtype: Optional[Union[torch.device, str]] = ...,
         non_blocking: bool = ...,
+        inplace: bool = False,
     ) -> T: ...

     @overload
@@ -10665,7 +10666,13 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T: ...
     def to(self: T, *, batch_size: torch.Size) -> T: ...

     def _to_cuda_with_pin_mem(
-        self, *, num_threads, device="cuda", non_blocking=None, to: Callable
+        self,
+        *,
+        num_threads,
+        device="cuda",
+        non_blocking=None,
+        to: Callable,
+        inplace: bool = False,
     ):
         if self.is_empty():
             return self.to(device)
@@ -10700,6 +10707,8 @@ def _to_cuda_with_pin_mem(
             is_leaf=_NESTED_TENSORS_AS_LISTS,
             propagate_lock=True,
             device=device,
+            out=self if inplace else None,
+            checked=True,
         )
         return result

@@ -10751,6 +10760,9 @@ def to(self, *args, **kwargs) -> T:
                ``max(1, torch.get_num_threads())`` threads will be spawn.
                ``num_threads=0`` will cancel any multithreading for the
                `pin_memory()` calls.
+            inplace (bool, optional): if ``True``, the data will be written in-place in the same tensordict.
+                This can be significantly faster whenever building a tensordict is CPU-overhead bound.
+                Defaults to ``False``.

        Returns:
            a new tensordict instance if the device differs from the tensordict
@@ -10779,6 +10791,7 @@ def to(self, *args, **kwargs) -> T:
             batch_size,
             non_blocking_pin,
             num_threads,
+            inplace,
         ) = _parse_to(*args, **kwargs)

         result = self
@@ -10791,6 +10804,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                inplace=inplace,
             )

         if non_blocking is None:
@@ -10822,11 +10836,13 @@ def to(tensor):
                 if num_threads is None:
                     num_threads = max(1, torch.get_num_threads() // 2)
                 result = self._to_cuda_with_pin_mem(
-                    num_threads=num_threads, to=to, device=device
+                    num_threads=num_threads, to=to, device=device, inplace=inplace
                 )
             else:
                 apply_kwargs["device"] = device if device is not None else self.device
                 apply_kwargs["batch_size"] = batch_size
+                apply_kwargs["out"] = self if inplace else None
+                apply_kwargs["checked"] = False
                 if non_blocking_pin:

                     def to_pinmem(tensor, _to=to):
@@ -10848,7 +10864,9 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result

-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(
+        self, *, device, pin_memory, num_threads, non_blocking, inplace
+    ):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
@@ -10911,8 +10929,17 @@ def set_(x):
                 storage_offset=storage_offset,
             )

+        if inplace:
+            out = self
+        else:
+            out = None
+
         result = self._fast_apply(
-            set_, device=torch.device(device), num_threads=num_threads
+            set_,
+            device=torch.device(device),
+            num_threads=num_threads,
+            out=out,
+            checked=True,
         )
         result._consolidated = {"storage": storage_cast}
         if "metadata" in self._consolidated:
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index 8b7ee49b3..d5f59110a 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -476,7 +476,7 @@ def keys(
     ) -> _PersistentTDKeysView:
         if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor):
             raise ValueError(
-                f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self)}."
+                f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self).__name__}."
             )
         return _PersistentTDKeysView(
             tensordict=self,
@@ -1026,7 +1026,11 @@ def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
             batch_size,
             non_blocking_pin,
             num_threads,
+            inplace,
         ) = _parse_to(*args, **kwargs)
+        if inplace:
+            raise TypeError(f"Cannot use inplace=True with {type(self).__name__}.to().")
+
         if non_blocking_pin:
             raise RuntimeError(
                 f"Cannot use non_blocking_pin=True {type(self).__name__}.to(). Call "
@@ -1181,7 +1185,7 @@ def _convert_inplace(self, inplace, key):

     def _set_non_tensor(self, key: NestedKey, value: Any):
         raise NotImplementedError(
-            f"set_non_tensor is not compatible with the tensordict type {type(self)}."
+            f"set_non_tensor is not compatible with the tensordict type {type(self).__name__}."
        )

    def _set_str(
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 37abe37c1..00fdc14f8 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -1707,7 +1707,7 @@ def _set_str(
     ):
         if is_non_tensor(self):
             if key != "data":
-                raise KeyError(f"only 'data' keys are supported for {type(self)}.")
+                raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
             while isinstance(value, (NonTensorData, NonTensorStack)):
                 value = value.data
             self._non_tensordict[key] = value
@@ -1737,7 +1737,7 @@ def _set_at_str(
     ):
         if is_non_tensor(self):
             if key != "data":
-                raise KeyError(f"only 'data' keys are supported for {type(self)}.")
+                raise KeyError(f"only 'data' keys are supported for {type(self).__name__}.")
             while isinstance(value, (NonTensorData, NonTensorStack)):
                 value = value.data
             self._non_tensordict[key] = value
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 81f0ecb00..b7dad78e4 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1356,6 +1356,7 @@ def _parse_to(*args, **kwargs):
     non_blocking_pin = kwargs.pop("non_blocking_pin", False)
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
+    inplace = kwargs.pop("inplace", False)
     if not is_dynamo_compiling():
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
             *args, **kwargs
         )
@@ -1397,6 +1398,7 @@
         batch_size,
         non_blocking_pin,
         num_threads,
+        inplace,
     )
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 9b332c379..8879f0e68 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -34,6 +34,11 @@ def get_available_devices():
             devices += [torch.device(f"cuda:{i}")]
             if i == 1:
                 break
+    # if torch.backends.mps.is_available():
+    #     for i in range(torch.mps.device_count()):
+    #         devices += [torch.device(f"mps:{i}")]
+    #         if i == 1:
+    #             break
     return devices


diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index df6d41f68..469454ac0 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -2056,8 +2056,26 @@ def test_split(self):

     def test_to(self):
         td = self.get_nested()
-        td = td.to("cpu:1")
-        assert isinstance(td.get("c")[0], self.TensorClass)
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu:1")
+        td_device = td.to(device)
+        assert isinstance(td_device.get("c")[0], self.TensorClass)
+        assert td_device is not td
+        assert td_device.device == device
+
+        td_device = td.to(device, inplace=True)
+        assert td_device is td
+        assert td_device.device == device
+
+        td_cpu = td_device.to("cpu", inplace=True)
+        assert td_cpu.device == torch.device("cpu")
+
+        td_double = td.to(torch.float64, inplace=True)
+        assert td_double is td
+        assert td_double.dtype == torch.double
+        assert td_double.device == torch.device("cpu")


 def test_decorator():
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 5409fd633..a100590bc 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2583,11 +2583,21 @@ def test_tensordict_device(self, device):
         assert tensordict["a"].device == device
         assert tensordict["b"].device == device

-        tensordict = TensorDict({"a": torch.randn(3, 4)}, [])
+        tensordict = TensorDict(
+            {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}, []
+        )
         tensordict = tensordict.to(device)
         assert tensordict.device == device
         assert tensordict["a"].device == device

+        tensordict_cpu = tensordict.to("cpu", inplace=True)
+        assert tensordict_cpu.device == torch.device("cpu")
+        for v in tensordict_cpu.values(True, True):
+            assert v.device == torch.device("cpu")
+        assert tensordict_cpu is tensordict
+        assert tensordict_cpu["b"] is tensordict["b"]
+        assert tensordict_cpu["b"].device == torch.device("cpu")
+
     @pytest.mark.skipif(
         torch.cuda.device_count() == 0, reason="No cuda device detected"
     )
@@ -3607,7 +3617,7 @@ def test_cast_to(self, td_name, device):
             "permute_td",
             "nested_stacked_td",
         ):
-            with pytest.raises(TypeError, match="Cannot pass batch-size to a "):
+            with pytest.raises(TypeError, match="Cannot pass batch-size to "):
                 td_dtype_device = td.to(
                     torch.device("cpu:1"), torch.int, batch_size=torch.Size([])
                 )
@@ -3626,7 +3636,7 @@ def test_cast_to(self, td_name, device):
             "permute_td",
             "nested_stacked_td",
         ):
-            with pytest.raises(TypeError, match="Cannot pass batch-size to a "):
+            with pytest.raises(TypeError, match="Cannot pass batch-size to "):
                 td.to(batch_size=torch.Size([]))
         else:
             td_batchsize = td.to(batch_size=torch.Size([]))
@@ -6357,6 +6367,38 @@ def test_tensordict_set_dict_value(self, td_name, device):
         with pytest.raises(KeyError, match=err_msg):
             td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5)))

+    def test_to_device_dtype_inplace(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        if torch.cuda.is_available():
+            dest = torch.device("cuda:0")
+        elif torch.mps.is_available():
+            dest = torch.device("mps:0")
+        else:
+            dest = torch.device("cpu")
+
+        if td_name in ("sub_td", "sub_td2"):
+            cm = pytest.raises(
+                TypeError,
+                match="Cannot send a _SubTensorDict instance to device/dtype inplace",
+            )
+        elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
+            cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
+        elif td.is_locked:
+            cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
+        else:
+            cm = contextlib.nullcontext()
+        with cm:
+            td.to(torch.float32, inplace=True)
+            assert td.dtype == torch.float32, td
+
+        with cm:
+            td.to(dest, inplace=True)
+            assert td.device == dest
+            for v in td.values(
+                True, True, is_leaf=tensordict_base._is_tensor_collection
+            ):
+                assert v.device == dest
+
     def test_to_dict_nested(self, td_name, device):
         def recursive_checker(cur_dict):
             for _, value in cur_dict.items():
@@ -9612,14 +9654,27 @@ def test_subtd(self):
         "non_blocking_pin", [False] if not torch.cuda.is_available() else [False, True]
     )
     @pytest.mark.parametrize("num_threads", [0, 1, 4, None])
-    def test_to(self, device, non_blocking_pin, num_threads):
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_to(self, device, non_blocking_pin, num_threads, inplace):
         td = TensorDict(
             {"": TensorDict({}, [3, 4, 1, 6])},
             batch_size=[3, 4, 1, 6],
             names=["a", "b", "c", "d"],
         )
-        tdt = td.to(device, non_blocking_pin=non_blocking_pin, num_threads=num_threads)
+        tdt = td.to(
+            device,
+            non_blocking_pin=non_blocking_pin,
+            num_threads=num_threads,
+            inplace=inplace,
+        )
         assert tdt.names == ["a", "b", "c", "d"]
+        assert tdt.device == device
+        for v in tdt.values(True, True):
+            assert v.device == device
+        if inplace:
+            assert tdt is td
+        else:
+            assert tdt is not td

     def test_unbind(self):
         td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
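
A minimal usage sketch of the `inplace` flag introduced by this patch, mirroring the patterns exercised in the new tests; the keys, shapes and the float64 cast below are illustrative only and are not part of the patch:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}},
        batch_size=[3],
    )

    # Default behaviour: to() builds and returns a new tensordict.
    td_f64 = td.to(torch.float64)
    assert td_f64 is not td

    # With inplace=True the same tensordict is updated and returned, which avoids
    # rebuilding the tensordict structure when that rebuild is CPU-overhead bound.
    td_same = td.to(torch.float64, inplace=True)
    assert td_same is td
    assert td_same.dtype == torch.float64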