From e696708865e67f019e5c1d75a2ec1f81ca479a77 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 16 Oct 2024 11:01:44 +0100 Subject: [PATCH] [Feature] TD+NJT to(device) support ghstack-source-id: 5f84ebc2a01e6dab26fe1d68d67bb166a295e885 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1022 --- tensordict/_reductions.py | 6 +- tensordict/base.py | 122 +++++++++++++++++++++++++++----------- tensordict/utils.py | 55 +++++++++++++++-- test/test_tensordict.py | 59 ++++++++++++++---- 4 files changed, 190 insertions(+), 52 deletions(-) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 2088d764c..7234a42bd 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -99,6 +99,8 @@ def from_metadata(metadata=metadata, prefix=None): value = value[: local_shape.numel()] value = value.view(local_shape) if key.startswith(""): + raise RuntimeError + elif key.startswith(""): nested_values = value nested_lengths = None continue @@ -106,8 +108,10 @@ def from_metadata(metadata=metadata, prefix=None): nested_lengths = value continue elif key.startswith(""): + from torch.nested._internal.nested_tensor import NestedTensor + offsets = value - value = torch.nested.nested_tensor_from_jagged( + value = NestedTensor( nested_values, offsets=offsets, lengths=nested_lengths ) key = key.replace("", "") diff --git a/tensordict/base.py b/tensordict/base.py index 1795504b6..5f54abb0b 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -20,7 +20,7 @@ from collections.abc import MutableMapping from concurrent.futures import Future, ThreadPoolExecutor, wait -from copy import copy, deepcopy +from copy import copy from functools import partial, wraps from pathlib import Path from textwrap import indent @@ -66,6 +66,7 @@ _prefix_last_key, _proc_init, _prune_selected_keys, + _rebuild_njt_from_njt, _set_max_batch_size, _shape, _split_tensordict, @@ -3591,7 +3592,7 @@ def assign( if getattr(value, "is_nested", False): if value.layout is torch.jagged: # Get the values - values = value.values() + values = value._values shape = [v if isinstance(v, int) else -1 for v in values.shape] # Get the offsets offsets = value._offsets @@ -3602,10 +3603,14 @@ def assign( # We will rely on the fact that the writing order is preserved in python dict # (since python 3.7). Later, we will read the NJT then the NJT offset in that order # to do the allocation. - flat_key_values[_prefix_last_key(total_key, "")] = values + flat_key_values[_prefix_last_key(total_key, "")] = value + flat_size.append(0) + flat_key_values[_prefix_last_key(total_key, "")] = ( + values + ) add_single_value( values, - _prefix_last_key(key, ""), + _prefix_last_key(key, ""), metadata_dict, values.dtype, shape, @@ -3811,12 +3816,14 @@ def assign( start, stop, njts, - njts_offsets, - njts_lengths, storage=storage, non_blocking=non_blocking, ): + """Reads a slice of the storage and assigns the resulting tensor in flat_dict.""" # v may need padding + if k[-1].startswith(""): + njts[k] = v + return v_pad = v.view(-1).view(torch.uint8) exp_length = stop - start pad = exp_length - v_pad.numel() @@ -3830,17 +3837,9 @@ def assign( if pad: new_v = new_v[: v.numel()] new_v = new_v.view(shape) - if k[-1].startswith(""): - njts[k] = new_v - elif k[-1].startswith(""): - njts_lengths[k] = new_v - elif k[-1].startswith(""): - njts_offsets[k] = new_v flat_dict[k] = new_v njts = {} - njts_offsets = {} - njts_lengths = {} if num_threads > 1: executor = ThreadPoolExecutor(num_threads) r = [] @@ -3853,8 +3852,6 @@ def assign( start=offsets[i], stop=offsets[i + 1], njts=njts, - njts_offsets=njts_offsets, - njts_lengths=njts_lengths, ) ) if not return_early: @@ -3872,25 +3869,25 @@ def assign( start=offsets[i], stop=offsets[i + 1], njts=njts, - njts_offsets=njts_offsets, - njts_lengths=njts_lengths, ) - for njt_key, njt_val in njts.items(): + for njt_key, njt in njts.items(): + newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) + njt_key_values = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) njt_key_offset = njt_key[:-1] + ( njt_key[-1].replace("", ""), ) njt_key_lengths = njt_key[:-1] + ( njt_key[-1].replace("", ""), ) - val = torch.nested.nested_tensor_from_jagged( - njt_val, - offsets=flat_dict[njt_key_offset], - lengths=flat_dict.get(njt_key_lengths), + val = _rebuild_njt_from_njt( + njt, + values=flat_dict.pop(njt_key_values), + offsets=flat_dict.pop(njt_key_offset), + lengths=flat_dict.pop(njt_key_lengths, None), ) del flat_dict[njt_key] - del flat_dict[njt_key_offset] - flat_dict.pop(njt_key_lengths, None) - newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) flat_dict[newkey] = val if non_blocking and device.type != "cuda": @@ -3910,6 +3907,8 @@ def _view_and_pad(tensor): items = [] for v in flat_dict.values(): + if v.is_nested: + continue if v.device != storage.device: v = v.to(storage.device, non_blocking=non_blocking) stride = v.stride() @@ -3928,9 +3927,13 @@ def _view_and_pad(tensor): flat_dict[k] = view_old_as_new(v, oldv) elif k[-1].startswith(""): # NJT/NT always comes before offsets/shapes - _nested_values = view_old_as_new(v, oldv) + nt = oldv + assert not v.numel() nt_lengths = None del flat_dict[k] + elif k[-1].startswith(""): + nt_vaues = view_old_as_new(v, oldv) + del flat_dict[k] elif k[-1].startswith(""): nt_lengths = view_old_as_new(v, oldv) del flat_dict[k] @@ -3939,15 +3942,16 @@ def _view_and_pad(tensor): nt_offsets = view_old_as_new(v, oldv) del flat_dict[k] - flat_dict[newk] = torch.nested.nested_tensor_from_jagged( - _nested_values, - offsets=nt_offsets, - lengths=nt_lengths, + val = _rebuild_njt_from_njt( + nt, values=nt_vaues, offsets=nt_offsets, lengths=nt_lengths ) + + flat_dict[newk] = val + # delete the nested value to make sure that if there was an # ordering mismatch we wouldn't be looking at the value key of # another nested tensor. - del _nested_values + del nt, nt_vaues, nt_offsets, nt_lengths else: flat_dict[k] = view_old_as_new(v, oldv) @@ -10459,9 +10463,52 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): untyped_storage = storage_cast.untyped_storage() def set_(x): + if x.is_nested: + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import FunctionalTensor + from torch.nested._internal.nested_tensor import ( + _tensor_symint_registry, + NestedTensor, + ) + from torch.nested._internal.ops import extract_kwargs + + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." + ) + kwargs = extract_kwargs(x) + values = x._values + lengths = x._lengths + offsets = x._offsets + kwargs["offsets"] = set_(offsets) + if lengths is not None: + kwargs["lengths"] = set_(lengths) + ragged_source = lengths + else: + ragged_source = offsets + new_thing = kwargs.get("lengths", kwargs.get("offsets")) + if isinstance(new_thing, (FakeTensor, FunctionalTensor)): + from torch._subclasses.functional_tensor import ( + mb_unwrap_functional_tensor, + ) + + # Temporary hack until we have the union find + tgt = mb_unwrap_functional_tensor(new_thing) + src = mb_unwrap_functional_tensor(ragged_source) + tgt.nested_int_memo = src.nested_int_memo + else: + _tensor_symint_registry[new_thing] = _tensor_symint_registry[ + ragged_source + ] + + return NestedTensor( + set_(values), + **kwargs, + ) storage_offset = x.storage_offset() stride = x.stride() - return torch.empty_like(x, device=device).set_( + return x.new_empty(0, device=device).set_( untyped_storage, size=x.shape, stride=stride, @@ -10473,7 +10520,14 @@ def set_(x): ) result._consolidated = {"storage": storage_cast} if "metadata" in self._consolidated: - result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"]) + # faster than deepcopy + def copy_dict(d): + return { + k: v if not isinstance(v, dict) else copy_dict(v) + for k, v in d.items() + } + + result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) if non_blocking in (False, None): if device.type == "cuda" and non_blocking is False: # sending to CUDA force sync diff --git a/tensordict/utils.py b/tensordict/utils.py index e914722fd..7d5c0a624 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1540,16 +1540,26 @@ def assert_close( elif not isinstance(input1, torch.Tensor): continue if input1.is_nested: - input1 = input1._base - input2 = input2._base - mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + input1v = input1.values() + input2v = input2.values() + mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() + input1o = input1.offsets() + input2o = input2.offsets() + mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() + else: + mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() mse = mse.div(input1.numel()).sqrt().item() local_msg = f"key {key} does not match, got mse = {mse:4.4f}" new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg - torch.testing.assert_close( - input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg - ) + if input1.is_nested: + torch.testing.assert_close( + input1v, input2v, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg + ) + else: + torch.testing.assert_close( + input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg + ) local_msg = f"key {key} matches" msg = "\t".join([local_msg, msg]) if len(msg) else local_msg @@ -2650,3 +2660,36 @@ def parse_tensor_dict_string(s: str): raise ValueError("Device not found in the string") tensor_dict = TensorDict(fields, batch_size=torch.Size(batch_size), device=device) return tensor_dict + + +def _rebuild_njt_from_njt(x, values, offsets, lengths): + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import FunctionalTensor + from torch.nested._internal.nested_tensor import ( + _tensor_symint_registry, + NestedTensor, + ) + from torch.nested._internal.ops import extract_kwargs + + kwargs = extract_kwargs(x) + kwargs["offsets"] = offsets + if x._lengths is not None: + kwargs["lengths"] = lengths + ragged_source = x._lengths + else: + ragged_source = x._offsets + new_thing = kwargs.get("lengths", kwargs.get("offsets")) + if isinstance(new_thing, (FakeTensor, FunctionalTensor)): + from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor + + # Temporary hack until we have the union find + tgt = mb_unwrap_functional_tensor(new_thing) + src = mb_unwrap_functional_tensor(ragged_source) + tgt.nested_int_memo = src.nested_int_memo + else: + _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source] + + return NestedTensor( + values, + **kwargs, + ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index d0f00a738..099d94b25 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7916,7 +7916,8 @@ def check_id(a, b): @pytest.mark.skipif(not _v2_5, reason="v2.5 required for this test") @pytest.mark.parametrize("device", [None, *get_available_devices()]) @pytest.mark.parametrize("use_file", [False, True]) - def test_consolidate_njt(self, device, use_file, tmpdir): + @pytest.mark.parametrize("num_threads", [0, 1, 4]) + def test_consolidate_njt(self, device, use_file, tmpdir, num_threads): td = TensorDict( { "a": torch.arange(3).expand(4, 3).clone(), @@ -7937,29 +7938,24 @@ def test_consolidate_njt(self, device, use_file, tmpdir): ) if not use_file: - td_c = td.consolidate() + td_c = td.consolidate(num_threads=num_threads) assert td_c.device == device else: filename = Path(tmpdir) / "file.mmap" - td_c = td.consolidate(filename=filename) + td_c = td.consolidate(filename=filename, num_threads=num_threads) assert td_c.device == torch.device("cpu") assert assert_allclose_td(TensorDict.from_consolidated(filename), td_c) assert hasattr(td_c, "_consolidated") assert type(td_c) == type(td) # noqa assert td_c["d"] == "a string!" - with ( - pytest.raises(KeyError) - if td.device != td_c.device and device is not None - else contextlib.nullcontext() - ): - # njt.to(device) is currently broken when it has lengths - assert_allclose_td(td.to(td_c.device), td_c) + + assert_allclose_td(td.to(td_c.device), td_c) tdload_make, tdload_data = _reduce_td(td) tdload = tdload_make(*tdload_data) assert (td == tdload).all() - td_c = td.consolidate() + td_c = td.consolidate(num_threads=num_threads) tdload_make, tdload_data = _reduce_td(td_c) tdload = tdload_make(*tdload_data) assert assert_allclose_td(td, tdload) @@ -7999,6 +7995,47 @@ def test_consolidate_to_device(self): assert td_c_device["d"] == [["a string!"] * 3] assert len(dataptrs) == 1 + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected") + def test_consolidate_to_device_njt(self): + td = TensorDict( + { + "a": torch.arange(3).expand(4, 3).clone(), + "d": "a string!", + "njt": torch.nested.nested_tensor_from_jagged( + torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10]) + ), + "njt_lengths": torch.nested.nested_tensor_from_jagged( + torch.arange(10), + offsets=torch.tensor([0, 2, 5, 8, 10]), + lengths=torch.tensor([2, 3, 3, 2]), + ), + }, + device="cpu", + batch_size=[4], + ) + device = torch.device("cuda:0") + td_c = td.consolidate() + assert td_c.device == torch.device("cpu") + td_c_device = td_c.to(device) + assert td_c_device.device == device + assert td_c_device.is_consolidated() + dataptrs = set() + for tensor in td_c_device.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS): + assert tensor.device == device + if tensor.is_nested: + vals = tensor._values + dataptrs.add(vals.untyped_storage().data_ptr()) + offsets = tensor._offsets + dataptrs.add(offsets.untyped_storage().data_ptr()) + lengths = tensor._lengths + if lengths is not None: + dataptrs.add(lengths.untyped_storage().data_ptr()) + else: + dataptrs.add(tensor.untyped_storage().data_ptr()) + assert len(dataptrs) == 1 + assert assert_allclose_td(td_c_device.cpu(), td) + assert td_c_device["njt_lengths"]._lengths is not None + def test_create_empty(self): td = LazyStackedTensorDict(stack_dim=0) assert td.device is None