[BugFix] Better repr of lazy stacks
ghstack-source-id: 7256b4c95b239bf9e6467c0ea687abe2c9179922
Pull Request resolved: #1076

(cherry picked from commit eaba711)
vmoens committed Nov 14, 2024
1 parent d3bcb6e commit e00965c
Showing 2 changed files with 47 additions and 2 deletions.
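
For context, a minimal sketch of the scenario this commit targets, mirroring the new test added below (shapes and names are simplified, and the exact repr output is not reproduced here; `TensorDict.lazy_stack`, `torch.cat` on tensordicts, and tuple-key assignment are used exactly as in the test):

import torch
from tensordict import TensorDict

# Two tensordicts whose nested entries differ in shape and in keys.
td0 = TensorDict(
    {"b": torch.zeros(3, 1), "my_nested_td": TensorDict({"a": torch.zeros(3, 1)}, [3, 1])},
    [3, 1],
)
td1 = torch.cat([td0, td0], 1)  # batch size becomes [3, 2]
td1["my_nested_td", "another"] = td1["my_nested_td", "a"]  # key exclusive to td1

# Lazy-stacking along dim=1 keeps the heterogeneous shapes and the exclusive key.
lazy = TensorDict.lazy_stack([td0, td1], dim=1)
print(lazy)  # ragged dims show as -1; the nested stack lists its exclusive_fields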
12 changes: 10 additions & 2 deletions tensordict/utils.py
@@ -1615,13 +1615,21 @@ def _td_fields(td: T, keys=None, sep=": ") -> str:
                 # we know td is lazy stacked and the key is a leaf
                 # so we can get the shape and escape the error
                 temp_td = td
-                from tensordict import LazyStackedTensorDict, TensorDictBase
+                from tensordict import (
+                    is_tensor_collection,
+                    LazyStackedTensorDict,
+                    TensorDictBase,
+                )

                 while isinstance(
                     temp_td, LazyStackedTensorDict
-                ):  # we need to grab the het tensor from the inner nesting level
+                ):  # we need to grab the heterogeneous tensor from the inner nesting level
                     temp_td = temp_td.tensordicts[0]
                 tensor = temp_td.get(key)
+                if is_tensor_collection(tensor):
+                    tensor = td.get(key)
+                    strs.append(_make_repr(key, tensor, td, sep=sep))
+                    continue

                 if isinstance(tensor, TensorDictBase):
                     substr = _td_fields(tensor)
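
In words: when the representative entry fetched from the innermost stacked tensordict turns out to be a tensor collection (for example a nested lazy stack with exclusive keys), the repr is now built from the full stack rather than from the first element only. A minimal standalone sketch of that fallback, not the actual tensordict internals (the helper name is hypothetical; `_make_repr`, `is_tensor_collection` and `LazyStackedTensorDict` are the objects visible in the hunk above):

from tensordict import is_tensor_collection, LazyStackedTensorDict
from tensordict.utils import _make_repr  # private helper called in the hunk above


def _append_heterogeneous_field(strs, td, key, sep=": "):
    # Hypothetical helper paraphrasing the branch added above.
    temp_td = td
    # Descend to the innermost stacked tensordict to fetch a representative value.
    while isinstance(temp_td, LazyStackedTensorDict):
        temp_td = temp_td.tensordicts[0]
    tensor = temp_td.get(key)
    if is_tensor_collection(tensor):
        # The entry is itself a tensor collection (e.g. a nested lazy stack):
        # re-fetch it from the full stack so its complete repr, exclusive
        # fields included, is rendered instead of the first element's view.
        tensor = td.get(key)
        strs.append(_make_repr(key, tensor, td, sep=sep))
        return True
    # Plain tensors keep the pre-existing handling in _td_fields.
    return False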
37 changes: 37 additions & 0 deletions test/test_tensordict.py
@@ -7168,6 +7168,43 @@ def test_repr_nested(self, device, dtype):
     is_shared={is_shared})"""
         assert repr(nested_td) == expected

+    def test_repr_nested_lazy(self, device, dtype):
+        nested_td0 = self.nested_td(device, dtype)
+        nested_td1 = torch.cat([nested_td0, nested_td0], 1)
+        nested_td1["my_nested_td", "another"] = nested_td1["my_nested_td", "a"]
+        lazy_nested_td = TensorDict.lazy_stack([nested_td0, nested_td1], dim=1)
+
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else nested_td0[:, 0]["b"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""LazyStackedTensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: LazyStackedTensorDict(
+            fields={{
+                a: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            exclusive_fields={{
+                1 ->
+                    another: Tensor(shape=torch.Size([4, 6, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([4, 2, -1, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared},
+            stack_dim=1)}},
+    exclusive_fields={{
+    }},
+    batch_size=torch.Size([4, 2, -1, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared},
+    stack_dim=1)"""
+        assert repr(lazy_nested_td) == expected
+
     def test_repr_nested_update(self, device, dtype):
         nested_td = self.nested_td(device, dtype)
         nested_td["my_nested_td"].rename_key_("a", "z")
