Skip to content

Commit

Permalink
[Feature] NonTensorStack.data
Browse files Browse the repository at this point in the history
ghstack-source-id: 86065377cc1cd7c7283ed0a468f5d5602d60526d
Pull Request resolved: #1132
  • Loading branch information
vmoens committed Dec 9, 2024
1 parent 1d93434 commit 4404abe
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
6 changes: 3 additions & 3 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2130,10 +2130,10 @@ def __getitem__(self, index: IndexType) -> Any:
if index_key:
leaf = self._get_tuple(index_key, NO_DEFAULT)
if is_non_tensor(leaf):
result = getattr(leaf, "data", NO_DEFAULT)
if result is NO_DEFAULT:
# Only lazy stacks of non tensors are actually tensordict instances
if isinstance(leaf, TensorDictBase):
return leaf.tolist()
return result
return leaf.data
return leaf
split_index = self._split_index(index)
converted_idx = split_index["index_dict"]
Expand Down
6 changes: 3 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6301,10 +6301,10 @@ def _get_tuple(self, key, default): ...
def _get_tuple_maybe_non_tensor(self, key, default):
result = self._get_tuple(key, default)
if is_non_tensor(result):
result_data = getattr(result, "data", NO_DEFAULT)
if result_data is NO_DEFAULT:
# Only lazy stacks of non tensors are actually tensordict instances
if isinstance(result, TensorDictBase):
return result.tolist()
return result_data
return result.data
return result

def get_at(
Expand Down
11 changes: 10 additions & 1 deletion tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3450,7 +3450,16 @@ def update_at_(

@property
def data(self):
raise AttributeError
"""Attempts to return the unique value in the stack.
Raises a ValueError if there is more than one unique value.
"""
try:
return NonTensorData._stack_non_tensor(
self.tensordicts, raise_if_non_unique=True
).data
except ValueError:
raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.")


_register_tensor_class(NonTensorStack)
Expand Down

0 comments on commit 4404abe

Please sign in to comment.