Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor keys, items and values #1058

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _has_exclusive_keys(self):
return False

@_fails_exclusive_keys
def to_dict(self) -> dict[str, Any]: ...
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]: ...

def _reduce_get_metadata(self):
metadata = {}
Expand Down Expand Up @@ -3417,7 +3417,7 @@ def _select(
) -> _CustomOpTensorDict:
if inplace:
raise RuntimeError("Cannot call select inplace on a lazy tensordict.")
return self.to_tensordict()._select(
return self.to_tensordict(retain_none=True)._select(
*keys, inplace=False, strict=strict, set_shared=set_shared
)

Expand Down
8 changes: 5 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -4210,7 +4210,7 @@ def _iter():
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
if not self.is_leaf(target_class):
continue
yield key
else:
Expand Down Expand Up @@ -4239,9 +4239,11 @@ def _iter_helper(
# For lazy stacks
value = value[0]
cls = type(value)
is_tc = _is_tensor_collection(cls)
if self.include_nested and is_tc:
if not is_non_tensor(cls):
yield from self._iter_helper(value, prefix=full_key)
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
yield from self._iter_helper(value, prefix=full_key)
if not self.leaves_only or is_leaf:
yield full_key

Expand Down
153 changes: 80 additions & 73 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5570,38 +5570,39 @@ def items(
Defaults to ``False``.

"""
if is_leaf is None:
is_leaf = _default_is_leaf
if sort:
yield from sorted(
self.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:

def _items():
if include_nested and leaves_only:
if is_leaf is None:
is_leaf = _default_is_leaf

if include_nested:
# check the conditions once only
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield k, val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield k, val
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
if _is_tensor_collection(cls):
if not is_non_tensor(cls):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
)
elif leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
Expand All @@ -5611,16 +5612,6 @@ def _items():
for k in self.keys():
yield k, self._get_str(k, NO_DEFAULT)

if sort:
yield from sorted(
_items(),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:
yield from _items()

def non_tensor_items(self, include_nested: bool = False):
"""Returns all non-tensor leaves, maybe recursively."""
return tuple(
Expand Down Expand Up @@ -5657,32 +5648,28 @@ def values(
Defaults to ``False``.

"""
if is_leaf is None:
is_leaf = _default_is_leaf
if sort:
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value
else:

if is_leaf is None:
is_leaf = _default_is_leaf

def _values():
# check the conditions once only
if include_nested and leaves_only:
if include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
else:
cls = type(val)
if not leaves_only or is_leaf(cls):
yield val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield val
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
if include_nested and _is_tensor_collection(cls):
if not is_non_tensor(cls):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
elif leaves_only:
for k in self.keys(sort=sort):
val = self._get_str(k, NO_DEFAULT)
Expand All @@ -5692,12 +5679,6 @@ def _values():
for k in self.keys(sort=sort):
yield self._get_str(k, NO_DEFAULT)

if not sort or not include_nested:
yield from _values()
else:
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value

@cache # noqa: B019
def _values_list(
self,
Expand Down Expand Up @@ -9350,9 +9331,16 @@ def _maybe_set_shared_attributes(self, result, lock=False):
if lock:
result.lock_()

def to_tensordict(self) -> T:
def to_tensordict(self, *, retain_none: bool | None = None) -> T:
"""Returns a regular TensorDict instance from the TensorDictBase.

Args:
retain_none (bool): if ``True``, the ``None`` values from tensorclass instances
will be written in the tensordict.
Otherwise they will be discarded. Default: ``True``.

.. note:: from v0.8, the default value will be switched to ``False``.

Returns:
a new TensorDict object containing the same values.

Expand All @@ -9364,7 +9352,11 @@ def to_tensordict(self) -> T:
key: (
value.clone()
if not _is_tensor_collection(type(value))
else value if is_non_tensor(value) else value.to_tensordict()
else (
value
if is_non_tensor(value)
else value.to_tensordict(retain_none=retain_none)
)
)
for key, value in self.items(is_leaf=_is_leaf_nontensor)
},
Expand Down Expand Up @@ -9467,12 +9459,27 @@ def as_tensor(tensor):

return self._fast_apply(as_tensor, propagate_lock=True)

def to_dict(self) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict."""
return {
key: value.to_dict() if _is_tensor_collection(type(value)) else value
for key, value in self.items()
}
def to_dict(self, *, retain_none: bool = True) -> dict[str, Any]:
"""Returns a dictionary with key-value pairs matching those of the tensordict.

Args:
retain_none (bool): if ``True``, the ``None`` values from tensorclass instances
will be written in the dictionary.
Otherwise, they will be discarded. Default: ``True``.

"""
result = {}
for key, value in self.items():
if _is_tensor_collection(type(value)):
if (
not retain_none
and _is_non_tensor(type(value))
and value.data is None
):
continue
value = value.to_dict(retain_none=retain_none)
result[key] = value
return result

def numpy(self):
"""Converts a tensordict to a (possibly nested) dictionary of numpy arrays.
Expand Down Expand Up @@ -9500,7 +9507,7 @@ def numpy(self):
{'a': {'b': array(0., dtype=float32), 'c': 'a string!'}}

"""
as_dict = self.to_dict()
as_dict = self.to_dict(retain_none=False)

def to_numpy(x):
if isinstance(x, torch.Tensor):
Expand Down Expand Up @@ -9541,7 +9548,7 @@ def dict_to_namedtuple(dictionary):
)
return cls(**dictionary)

return dict_to_namedtuple(self.to_dict())
return dict_to_namedtuple(self.to_dict(retain_none=False))

@classmethod
def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def from_dict(cls, *args, **kwargs):
return TensorDictParams(td)

@_fallback
def to_tensordict(self): ...
def to_tensordict(self, *, retain_none: bool | None = None): ...

@_fallback
def to_h5(
Expand Down
Loading
Loading