diff --git a/tensordict/base.py b/tensordict/base.py index 0b0fd8cb2..c44d1b4c3 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10607,7 +10607,7 @@ def to(tensor): else: apply_kwargs["device"] = device if device is not None else self.device apply_kwargs["batch_size"] = batch_size - apply_kwargs["out"] = self if inplace else None + apply_kwargs["out"] = None apply_kwargs["checked"] = True if non_blocking_pin: diff --git a/tensordict/utils.py b/tensordict/utils.py index fb485573c..e6d62087c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2779,7 +2779,7 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): tgt = mb_unwrap_functional_tensor(new_thing) src = mb_unwrap_functional_tensor(ragged_source) tgt.nested_int_memo = src.nested_int_memo - else: + elif new_thing is not None: _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source] return NestedTensor(