Skip to content

Commit

Permalink
[Refactor] __eq__ to identity check in non-tensor stacking
Browse files Browse the repository at this point in the history
ghstack-source-id: ccbe882e12370b4145d7d834012cc3cfa6376f6c
Pull Request resolved: #1083
  • Loading branch information
vmoens committed Nov 11, 2024
1 parent 9607cf0 commit b39f0db
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 28 deletions.
1 change: 0 additions & 1 deletion tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def _rebuild_buffer(data, requires_grad, backward_hooks):

def _dispatch_td_nn_modules():
    """Returns ``True`` if @dispatch should be used.

    Not using dispatch is faster and also more compatible with torch.compile.
    """
    # The flag is only read here, so no ``global`` declaration is needed:
    # that statement is required only when a function rebinds a module-level name.
    return DISPATCH_TDNN_MODULES


Expand Down
55 changes: 41 additions & 14 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,20 +2748,23 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
# checks have been performed previously, so we're sure the list is non-empty
first = list_of_non_tensor[0]

def _check_equal(a, b):
try:
if isinstance(a, _ACCEPTED_CLASSES) or isinstance(b, _ACCEPTED_CLASSES):
return (a == b).all() and a.shape == b.shape
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
return (a == b).all() and a.shape == b.shape
iseq = a == b
except Exception:
iseq = False
return iseq

if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all(
_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]
):
ids = set()
firstdata = NO_DEFAULT
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
return_stack = True
break
if firstdata is NO_DEFAULT:
firstdata = data.data
ids.add(id(data.data))
if len(ids) > 1:
if _check_equal(data.data, firstdata):
continue
return_stack = True
break
else:
return_stack = False
if not return_stack:
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
Expand Down Expand Up @@ -3442,3 +3445,27 @@ class TensorClass(metaclass=_TensorClassMeta):
_nocast: bool = False
_frozen: bool = False
...


# TODO: v0.8: remove this func entirely
def _check_equal(a, b):
    """Check that the payloads of two non-tensor data objects match in value.

    This value check is being replaced by an identity match, which is faster
    and easier to handle. Until then, a value match emits a ``UserWarning``
    announcing the upcoming behaviour change.

    Args:
        a: first payload to compare.
        b: second payload to compare.

    Returns:
        bool: ``True`` if ``a`` and ``b`` compare equal (same shape and same
        content when either is one of ``_ACCEPTED_CLASSES`` or a
        ``np.ndarray``), ``False`` otherwise, including when the comparison
        itself raises.
    """
    try:
        array_like = (
            isinstance(a, _ACCEPTED_CLASSES)
            or isinstance(b, _ACCEPTED_CLASSES)
            or isinstance(a, np.ndarray)
            or isinstance(b, np.ndarray)
        )
        if array_like:
            # Compare shapes first: it is cheap and short-circuits the
            # elementwise comparison (and any broadcasting) on a mismatch.
            iseq = bool(a.shape == b.shape and (a == b).all())
        else:
            iseq = bool(a == b)
    except Exception:
        # Deliberate best-effort: payloads are arbitrary user objects whose
        # ``__eq__`` / ``shape`` may be missing or may raise; treat any
        # failure as "not equal".
        iseq = False
    if iseq:
        warnings.warn(
            "The content of the stacked NonTensorData objects matched in value but not identity. "
            "This will currently return a NonTensorData but in the future (v0.8) it will return "
            "a NonTensorStack instead. "
            "To obtain a non-tensor stack, use `TensorDict.lazy_stack` instead.",
            category=UserWarning,
        )
    return iseq
2 changes: 1 addition & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2759,7 +2759,7 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
tgt = mb_unwrap_functional_tensor(new_thing)
src = mb_unwrap_functional_tensor(ragged_source)
tgt.nested_int_memo = src.nested_int_memo
else:
elif new_thing is not None:
_tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]

return NestedTensor(
Expand Down
14 changes: 2 additions & 12 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,18 +1909,8 @@ def z(self) -> torch.Tensor:
return self._z()

obj = torch.ones(())
y0 = Y(
weakref.ref(obj),
batch_size=[
1,
],
)
y1 = Y(
weakref.ref(obj),
batch_size=[
1,
],
)
y0 = Y(weakref.ref(obj), batch_size=[1])
y1 = Y(weakref.ref(obj), batch_size=[1])
y = torch.cat([y0, y1])
assert y.z.shape == torch.Size(())
y = torch.stack([y0, y1])
Expand Down
14 changes: 14 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
pytest.mark.filterwarnings(
"ignore:Lazy modules are a new feature under heavy development so changes to the API or functionality"
),
pytest.mark.filterwarnings(
"ignore:The content of the stacked NonTensorData objects matched in value but not identity"
),
]

mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"
Expand Down Expand Up @@ -10881,6 +10884,17 @@ def test_memmap_stack(self, tmpdir, json_serializable, device):
assert data_memmap._is_memmap

def test_memmap_stack_updates(self, tmpdir):
with pytest.warns(
UserWarning,
match="The content of the stacked NonTensorData objects matched in value but not identity",
):
data = torch.stack(
[
NonTensorData(data=torch.zeros(())),
NonTensorData(data=torch.zeros(())),
],
0,
)
data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0)
assert is_non_tensor(data)
data = torch.stack([data] * 3)
Expand Down

0 comments on commit b39f0db

Please sign in to comment.