[Feature] TD+NJT to(device) support
ghstack-source-id: 5f84ebc2a01e6dab26fe1d68d67bb166a295e885
Pull Request resolved: #1022
vmoens committed Oct 16, 2024
1 parent 7e45bcc commit e696708
Showing 4 changed files with 190 additions and 52 deletions.
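
The change below lets a consolidated TensorDict containing jagged-layout nested tensors (NJTs) be moved across devices. A minimal usage sketch of what this enables; shapes and device names are illustrative and not taken from this PR's tests:

    import torch
    from tensordict import TensorDict

    # A jagged (NJT) nested tensor: two rows of different lengths, same trailing dim.
    njt = torch.nested.nested_tensor(
        [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
    )
    td = TensorDict({"x": njt, "y": torch.randn(2, 4)}, batch_size=[2])

    # Consolidate into a single flat storage, then move everything
    # (including the NJT's values, offsets and lengths) in one transfer.
    td_c = td.consolidate()
    if torch.cuda.is_available():
        td_cuda = td_c.to("cuda")
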
6 changes: 5 additions & 1 deletion tensordict/_reductions.py
@@ -99,15 +99,19 @@ def from_metadata(metadata=metadata, prefix=None):
value = value[: local_shape.numel()]
value = value.view(local_shape)
if key.startswith("<NJT>"):
raise RuntimeError
elif key.startswith("<NJT_VALUES>"):
nested_values = value
nested_lengths = None
continue
elif key.startswith("<NJT_LENGTHS>"):
nested_lengths = value
continue
elif key.startswith("<NJT_OFFSETS>"):
from torch.nested._internal.nested_tensor import NestedTensor

offsets = value
value = torch.nested.nested_tensor_from_jagged(
value = NestedTensor(
nested_values, offsets=offsets, lengths=nested_lengths
)
key = key.replace("<NJT_OFFSETS>", "")
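
The deserialization branch above rebuilds each NJT from its stored <NJT_VALUES>, <NJT_LENGTHS> and <NJT_OFFSETS> entries through the internal NestedTensor constructor. For reference, a sketch of the same reconstruction via the public API (component shapes are illustrative):

    import torch

    # Illustrative components of a jagged tensor with rows of length 3 and 5.
    values = torch.randn(8, 4)           # all rows concatenated along dim 0
    offsets = torch.tensor([0, 3, 8])    # row boundaries into values

    # Public-API counterpart of NestedTensor(values, offsets=..., lengths=...):
    njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
    # njt.shape is (2, j, 4), where j is the symbolic jagged dimension.
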
122 changes: 88 additions & 34 deletions tensordict/base.py
@@ -20,7 +20,7 @@
from collections.abc import MutableMapping

from concurrent.futures import Future, ThreadPoolExecutor, wait
from copy import copy, deepcopy
from copy import copy
from functools import partial, wraps
from pathlib import Path
from textwrap import indent
@@ -66,6 +66,7 @@
_prefix_last_key,
_proc_init,
_prune_selected_keys,
_rebuild_njt_from_njt,
_set_max_batch_size,
_shape,
_split_tensordict,
@@ -3591,7 +3592,7 @@ def assign(
if getattr(value, "is_nested", False):
if value.layout is torch.jagged:
# Get the values
values = value.values()
values = value._values
shape = [v if isinstance(v, int) else -1 for v in values.shape]
# Get the offsets
offsets = value._offsets
@@ -3602,10 +3603,14 @@
# We will rely on the fact that the writing order is preserved in python dict
# (since python 3.7). Later, we will read the NJT then the NJT offset in that order
# to do the allocation.
flat_key_values[_prefix_last_key(total_key, "<NJT>")] = values
flat_key_values[_prefix_last_key(total_key, "<NJT>")] = value
flat_size.append(0)
flat_key_values[_prefix_last_key(total_key, "<NJT_VALUES>")] = (
values
)
add_single_value(
values,
_prefix_last_key(key, "<NJT>"),
_prefix_last_key(key, "<NJT_VALUES>"),
metadata_dict,
values.dtype,
shape,
@@ -3811,12 +3816,14 @@ def assign(
start,
stop,
njts,
njts_offsets,
njts_lengths,
storage=storage,
non_blocking=non_blocking,
):
"""Reads a slice of the storage and assigns the resulting tensor in flat_dict."""
# v may need padding
if k[-1].startswith("<NJT>"):
njts[k] = v
return
v_pad = v.view(-1).view(torch.uint8)
exp_length = stop - start
pad = exp_length - v_pad.numel()
@@ -3830,17 +3837,9 @@
if pad:
new_v = new_v[: v.numel()]
new_v = new_v.view(shape)
if k[-1].startswith("<NJT>"):
njts[k] = new_v
elif k[-1].startswith("<NJT_LENGTHS>"):
njts_lengths[k] = new_v
elif k[-1].startswith("<NJT_OFFSETS>"):
njts_offsets[k] = new_v
flat_dict[k] = new_v

njts = {}
njts_offsets = {}
njts_lengths = {}
if num_threads > 1:
executor = ThreadPoolExecutor(num_threads)
r = []
@@ -3853,8 +3852,6 @@
start=offsets[i],
stop=offsets[i + 1],
njts=njts,
njts_offsets=njts_offsets,
njts_lengths=njts_lengths,
)
)
if not return_early:
@@ -3872,25 +3869,25 @@
start=offsets[i],
stop=offsets[i + 1],
njts=njts,
njts_offsets=njts_offsets,
njts_lengths=njts_lengths,
)
for njt_key, njt_val in njts.items():
for njt_key, njt in njts.items():
newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
njt_key_values = njt_key[:-1] + (
njt_key[-1].replace("<NJT>", "<NJT_VALUES>"),
)
njt_key_offset = njt_key[:-1] + (
njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),
)
njt_key_lengths = njt_key[:-1] + (
njt_key[-1].replace("<NJT>", "<NJT_LENGTHS>"),
)
val = torch.nested.nested_tensor_from_jagged(
njt_val,
offsets=flat_dict[njt_key_offset],
lengths=flat_dict.get(njt_key_lengths),
val = _rebuild_njt_from_njt(
njt,
values=flat_dict.pop(njt_key_values),
offsets=flat_dict.pop(njt_key_offset),
lengths=flat_dict.pop(njt_key_lengths, None),
)
del flat_dict[njt_key]
del flat_dict[njt_key_offset]
flat_dict.pop(njt_key_lengths, None)
newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
flat_dict[newkey] = val

if non_blocking and device.type != "cuda":
@@ -3910,6 +3907,8 @@ def _view_and_pad(tensor):

items = []
for v in flat_dict.values():
if v.is_nested:
continue
if v.device != storage.device:
v = v.to(storage.device, non_blocking=non_blocking)
stride = v.stride()
@@ -3928,9 +3927,13 @@ def _view_and_pad(tensor):
flat_dict[k] = view_old_as_new(v, oldv)
elif k[-1].startswith("<NJT>"):
# NJT/NT always comes before offsets/shapes
_nested_values = view_old_as_new(v, oldv)
nt = oldv
assert not v.numel()
nt_lengths = None
del flat_dict[k]
elif k[-1].startswith("<NJT_VALUES>"):
nt_values = view_old_as_new(v, oldv)
del flat_dict[k]
elif k[-1].startswith("<NJT_LENGTHS>"):
nt_lengths = view_old_as_new(v, oldv)
del flat_dict[k]
@@ -3939,15 +3942,16 @@ def _view_and_pad(tensor):
nt_offsets = view_old_as_new(v, oldv)
del flat_dict[k]

flat_dict[newk] = torch.nested.nested_tensor_from_jagged(
_nested_values,
offsets=nt_offsets,
lengths=nt_lengths,
val = _rebuild_njt_from_njt(
nt, values=nt_values, offsets=nt_offsets, lengths=nt_lengths
)

flat_dict[newk] = val

# delete the nested value to make sure that if there was an
# ordering mismatch we wouldn't be looking at the value key of
# another nested tensor.
del _nested_values
del nt, nt_values, nt_offsets, nt_lengths
else:
flat_dict[k] = view_old_as_new(v, oldv)
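
Both the writing path (consolidate) and the reading path above spread one NJT over several flat entries whose key prefixes encode their role, with insertion order guaranteed by Python dicts. A rough sketch of that layout for a single leaf stored under the key ("a",); the literal dict below is illustrative, the real code builds it through _prefix_last_key:

    import torch

    njt = torch.nested.nested_tensor(
        [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
    )

    # Insertion order matters: the <NJT> placeholder and the values come first,
    # the offsets last, which is what the reconstruction loop depends on.
    flat_key_values = {
        ("<NJT>a",): njt,                   # placeholder carrying NJT metadata (0 bytes of storage)
        ("<NJT_VALUES>a",): njt._values,    # concatenated row data
        ("<NJT_LENGTHS>a",): njt._lengths,  # None here; the real code only stores it when present
        ("<NJT_OFFSETS>a",): njt._offsets,  # row boundaries
    }
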

@@ -10459,9 +10463,52 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
untyped_storage = storage_cast.untyped_storage()

def set_(x):
if x.is_nested:
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.nested._internal.nested_tensor import (
_tensor_symint_registry,
NestedTensor,
)
from torch.nested._internal.ops import extract_kwargs

if x.layout != torch.jagged:
raise RuntimeError(
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
"Please raise an issue on GitHub."
)
kwargs = extract_kwargs(x)
values = x._values
lengths = x._lengths
offsets = x._offsets
kwargs["offsets"] = set_(offsets)
if lengths is not None:
kwargs["lengths"] = set_(lengths)
ragged_source = lengths
else:
ragged_source = offsets
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
from torch._subclasses.functional_tensor import (
mb_unwrap_functional_tensor,
)

# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
ragged_source
]

return NestedTensor(
set_(values),
**kwargs,
)
storage_offset = x.storage_offset()
stride = x.stride()
return torch.empty_like(x, device=device).set_(
return x.new_empty(0, device=device).set_(
untyped_storage,
size=x.shape,
stride=stride,
@@ -10473,7 +10520,14 @@ def set_(x):
)
result._consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
# faster than deepcopy
def copy_dict(d):
return {
k: v if not isinstance(v, dict) else copy_dict(v)
for k, v in d.items()
}

result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
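
The device transfer itself works by casting the consolidated storage once and re-viewing every leaf onto it with Tensor.set_, rather than copying leaf by leaf. A minimal sketch of that mechanism with plain tensors; sizes, strides and offsets below are illustrative:

    import torch

    # One flat storage holding two tensors back to back.
    flat = torch.arange(12, dtype=torch.float32)
    storage = flat.untyped_storage()

    # Re-view slices of the storage as tensors, as set_() does above.
    a = torch.empty(0, dtype=torch.float32).set_(
        storage, storage_offset=0, size=(2, 3), stride=(3, 1)
    )
    b = torch.empty(0, dtype=torch.float32).set_(
        storage, storage_offset=6, size=(3, 2), stride=(2, 1)
    )
    # a and b alias flat's memory; the to(device) path above makes one cast
    # copy of the whole storage and re-views each leaf onto it the same way.
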
55 changes: 49 additions & 6 deletions tensordict/utils.py
@@ -1540,16 +1540,26 @@ def assert_close(
elif not isinstance(input1, torch.Tensor):
continue
if input1.is_nested:
input1 = input1._base
input2 = input2._base
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
input1v = input1.values()
input2v = input2.values()
mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
input1o = input1.offsets()
input2o = input2.offsets()
mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
else:
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
mse = mse.div(input1.numel()).sqrt().item()

local_msg = f"key {key} does not match, got mse = {mse:4.4f}"
new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
torch.testing.assert_close(
input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
)
if input1.is_nested:
torch.testing.assert_close(
input1v, input2v, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
)
else:
torch.testing.assert_close(
input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
)
local_msg = f"key {key} matches"
msg = "\t".join([local_msg, msg]) if len(msg) else local_msg
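
With this change assert_close compares jagged tensors through their components rather than through ._base. The equivalent direct check with torch.testing (the tensors below are illustrative):

    import torch

    a = torch.nested.nested_tensor([torch.ones(3, 4), torch.ones(5, 4)], layout=torch.jagged)
    b = torch.nested.nested_tensor([torch.ones(3, 4), torch.ones(5, 4)], layout=torch.jagged)

    # Compare the flat values and the row boundaries separately, as done above.
    torch.testing.assert_close(a.values(), b.values())
    torch.testing.assert_close(a.offsets(), b.offsets())
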

@@ -2650,3 +2660,36 @@ def parse_tensor_dict_string(s: str):
raise ValueError("Device not found in the string")
tensor_dict = TensorDict(fields, batch_size=torch.Size(batch_size), device=device)
return tensor_dict


def _rebuild_njt_from_njt(x, values, offsets, lengths):
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.nested._internal.nested_tensor import (
_tensor_symint_registry,
NestedTensor,
)
from torch.nested._internal.ops import extract_kwargs

kwargs = extract_kwargs(x)
kwargs["offsets"] = offsets
if x._lengths is not None:
kwargs["lengths"] = lengths
ragged_source = x._lengths
else:
ragged_source = x._offsets
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor

# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]

return NestedTensor(
values,
**kwargs,
)
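
A usage sketch of the new helper: rebuild an NJT around replacement components (clones here, device-moved tensors in the code above) while reusing the template's jagged metadata and nested-int registry entry. This touches torch.nested internals and is sensitive to the PyTorch version, so treat it as illustrative:

    import torch
    from tensordict.utils import _rebuild_njt_from_njt

    njt = torch.nested.nested_tensor(
        [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
    )

    # The template njt supplies the extract_kwargs() metadata; the new
    # components replace its values/offsets/lengths.
    rebuilt = _rebuild_njt_from_njt(
        njt,
        values=njt._values.clone(),
        offsets=njt._offsets.clone(),
        lengths=njt._lengths.clone() if njt._lengths is not None else None,
    )
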