Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 9, 2024
1 parent 044f888 commit f7665db
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
6 changes: 3 additions & 3 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,10 +2127,10 @@ def __getitem__(self, index: IndexType) -> Any:
if index_key:
leaf = self._get_tuple(index_key, NO_DEFAULT)
if is_non_tensor(leaf):
result = getattr(leaf, "data", NO_DEFAULT)
if result is NO_DEFAULT:
# Only lazy stacks of non tensors are actually tensordict instances
if isinstance(leaf, TensorDictBase):
return leaf.tolist()
return result
return leaf.data
return leaf
split_index = self._split_index(index)
converted_idx = split_index["index_dict"]
Expand Down
6 changes: 3 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6268,10 +6268,10 @@ def _get_tuple(self, key, default): ...
def _get_tuple_maybe_non_tensor(self, key, default):
result = self._get_tuple(key, default)
if is_non_tensor(result):
result_data = getattr(result, "data", NO_DEFAULT)
if result_data is NO_DEFAULT:
# Only lazy stacks of non tensors are actually tensordict instances
if isinstance(result, TensorDictBase):
return result.tolist()
return result_data
return result.data
return result

def get_at(
Expand Down
11 changes: 10 additions & 1 deletion tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3445,7 +3445,16 @@ def update_at_(

@property
def data(self):
raise AttributeError
"""Attempts to return the unique value in the stack.
Raises a ValueError if there is more than one unique value.
"""
try:
return NonTensorData._stack_non_tensor(
self.tensordicts, raise_if_non_unique=True
).data
except ValueError:
raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.")


_register_tensor_class(NonTensorStack)
Expand Down

0 comments on commit f7665db

Please sign in to comment.