Skip to content

Commit

Permalink
[Feature] min, amin, max, amax, cummin, cummax
Browse files Browse the repository at this point in the history
ghstack-source-id: 81d6836892b182e60cdbc9ef9ebb6637ad611518
Pull Request resolved: #1057
  • Loading branch information
vmoens committed Oct 24, 2024
1 parent 8c65dcb commit 9ca4207
Show file tree
Hide file tree
Showing 6 changed files with 398 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,8 @@ def _apply_nest(
arg for arg in (batch_size, device, names, constructor_kwargs)
):
raise ValueError(
"Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True."
"Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True. Got args "
f"batch_size={batch_size}, device={device}, names={names}, constructor_kwargs={constructor_kwargs}"
)
if out is not None:
if not isinstance(out, LazyStackedTensorDict):
Expand Down
14 changes: 12 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,9 @@ def _cast_reduction(
keepdim=NO_DEFAULT,
tuple_ok=True,
further_reduce: bool,
values_only: bool = True,
call_on_nested: bool = True,
batch_size=None,
**kwargs,
):
if further_reduce:
Expand Down Expand Up @@ -1015,9 +1018,16 @@ def reduction(val):
result = getattr(val, reduction_name)(
**kwargs,
)
if isinstance(result, tuple):
if values_only:
result = result.values
else:
return TensorDict.from_namedtuple(result)
return result

if dim not in (None, NO_DEFAULT):
if batch_size is not None:
pass
elif dim is not None and dim is not NO_DEFAULT:
if not keepdim:
if isinstance(dim, tuple):
batch_size = [
Expand All @@ -1043,7 +1053,7 @@ def reduction(val):

return self._fast_apply(
reduction,
call_on_nested=True,
call_on_nested=call_on_nested,
batch_size=torch.Size(batch_size),
device=self.device,
names=names,
Expand Down
246 changes: 243 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class _NoDefault(enum.IntEnum):


NO_DEFAULT = _NoDefault.ZERO

assert not NO_DEFAULT
T = TypeVar("T", bound="TensorDictBase")


Expand Down Expand Up @@ -575,6 +575,232 @@ def isreal(self) -> T:
propagate_lock=True,
)

def amin(
self,
dim: int | NO_DEFAULT = NO_DEFAULT,
keepdim: bool = False,
*,
reduce: bool | None = None,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the minimum values of all elements in the input tensordict.
Same as :meth:`~.min` with ``return_indices=False``.
"""
return self.min(dim=dim, keepdim=keepdim, reduce=reduce, return_indices=False)

def min(
self,
dim: int | NO_DEFAULT = NO_DEFAULT,
keepdim: bool = False,
*,
reduce: bool | None = None,
return_indices: bool = True,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the minimum values of all elements in the input tensordict.
Args:
dim (int, optional): if ``None``, returns a dimensionless
tensordict containing the min value of all leaves (if this can be computed).
If integer, `min` is called upon the dimension specified if
and only if this dimension is compatible with the tensordict
shape.
keepdim (bool): whether the output tensor has dim retained or not.
Keyword Args:
reduce (bool, optional): if ``True``, the reduciton will occur across all TensorDict values
and a single reduced tensor will be returned.
Defaults to ``False``.
return_argmins (bool, optional): :func:`~torch.min` returns a named tuple with values and indices
when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a tensorclass
with entries ``"values"`` and ``"indices"`` with idendical structure within. Defaults to ``True``.
"""
result = self._cast_reduction(
reduction_name="min",
dim=dim,
keepdim=keepdim,
further_reduce=reduce,
tuple_ok=False,
values_only=not return_indices,
call_on_nested=False,
)
if dim is not NO_DEFAULT and return_indices:
# Split the tensordict
from .return_types import min

values_dict = {}
indices_dict = {}
for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
if key[-1] == "values":
values_dict[key] = key[:-1]
else:
indices_dict[key] = key[:-1]
return min(
*result.split_keys(values_dict, indices_dict),
batch_size=result.batch_size,
)
return result

def amax(
self,
dim: int | NO_DEFAULT = NO_DEFAULT,
keepdim: bool = False,
*,
reduce: bool | None = None,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the maximum values of all elements in the input tensordict.
Same as :meth:`~.max` with ``return_indices=False``.
"""
return self.max(dim=dim, keepdim=keepdim, reduce=reduce, return_indices=False)

def max(
self,
dim: int | NO_DEFAULT = NO_DEFAULT,
keepdim: bool = False,
*,
reduce: bool | None = None,
return_indices: bool = True,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the maximum values of all elements in the input tensordict.
Args:
dim (int, optional): if ``None``, returns a dimensionless
tensordict containing the max value of all leaves (if this can be computed).
If integer, `max` is called upon the dimension specified if
and only if this dimension is compatible with the tensordict
shape.
keepdim (bool): whether the output tensor has dim retained or not.
Keyword Args:
reduce (bool, optional): if ``True``, the reduciton will occur across all TensorDict values
and a single reduced tensor will be returned.
Defaults to ``False``.
return_argmins (bool, optional): :func:`~torch.max` returns a named tuple with values and indices
when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a tensorclass
with entries ``"values"`` and ``"indices"`` with idendical structure within. Defaults to ``True``.
"""
result = self._cast_reduction(
reduction_name="max",
dim=dim,
keepdim=keepdim,
further_reduce=reduce,
tuple_ok=False,
values_only=not return_indices,
call_on_nested=False,
)
if dim is not NO_DEFAULT and return_indices:
# Split the tensordict
from .return_types import max

values_dict = {}
indices_dict = {}
for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
if key[-1] == "values":
values_dict[key] = key[:-1]
else:
indices_dict[key] = key[:-1]
return max(
*result.split_keys(values_dict, indices_dict),
batch_size=result.batch_size,
)
return result

def cummin(
self,
dim: int,
*,
reduce: bool | None = None,
return_indices: bool = True,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the cumulative minimum values of all elements in the input tensordict.
Args:
dim (int): integer representing the dimension along which to perform the cummin operation.
Keyword Args:
reduce (bool, optional): if ``True``, the reduciton will occur across all TensorDict values
and a single reduced tensor will be returned.
Defaults to ``False``.
return_argmins (bool, optional): :func:`~torch.cummin` returns a named tuple with values and indices
when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a tensorclass
with entries ``"values"`` and ``"indices"`` with idendical structure within. Defaults to ``True``.
"""
result = self._cast_reduction(
reduction_name="cummin",
dim=dim,
further_reduce=reduce,
tuple_ok=False,
values_only=not return_indices,
call_on_nested=False,
batch_size=self.batch_size,
)
if dim is not NO_DEFAULT and return_indices:
# Split the tensordict
from .return_types import cummin

values_dict = {}
indices_dict = {}
for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
if key[-1] == "values":
values_dict[key] = key[:-1]
else:
indices_dict[key] = key[:-1]
return cummin(
*result.split_keys(values_dict, indices_dict),
batch_size=result.batch_size,
)
return result

def cummax(
self,
dim: int,
*,
reduce: bool | None = None,
return_indices: bool = True,
) -> TensorDictBase | torch.Tensor: # noqa: D417
"""Returns the cumulative maximum values of all elements in the input tensordict.
Args:
dim (int): integer representing the dimension along which to perform the cummax operation.
Keyword Args:
reduce (bool, optional): if ``True``, the reduciton will occur across all TensorDict values
and a single reduced tensor will be returned.
Defaults to ``False``.
return_argmins (bool, optional): :func:`~torch.cummax` returns a named tuple with values and indices
when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a tensorclass
with entries ``"values"`` and ``"indices"`` with idendical structure within. Defaults to ``True``.
"""
result = self._cast_reduction(
reduction_name="cummax",
dim=dim,
further_reduce=reduce,
tuple_ok=False,
values_only=not return_indices,
call_on_nested=False,
batch_size=self.batch_size,
)
if dim is not NO_DEFAULT and return_indices:
# Split the tensordict
from .return_types import cummax

values_dict = {}
indices_dict = {}
for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
if key[-1] == "values":
values_dict[key] = key[:-1]
else:
indices_dict[key] = key[:-1]
return cummax(
*result.split_keys(values_dict, indices_dict),
batch_size=result.batch_size,
)
return result

def mean(
self,
dim: int | Tuple[int] = NO_DEFAULT,
Expand Down Expand Up @@ -4840,7 +5066,9 @@ def _filter(x):

def filter_empty_(self):
"""Filters out all empty tensordicts in-place."""
for key, val in list(self.items(True, is_leaf=_NESTED_TENSORS_AS_LISTS)):
for key, val in reversed(
list(self.items(True, is_leaf=_NESTED_TENSORS_AS_LISTS, sort=True))
):
if _is_tensor_collection(type(val)) and val.is_empty():
del self[key]
return self
Expand Down Expand Up @@ -9581,6 +9809,15 @@ def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
def namedtuple_to_dict(namedtuple_obj):
if is_namedtuple(namedtuple_obj):
namedtuple_obj = namedtuple_obj._asdict()

else:
from torch.return_types import cummax, cummin, max, min

if isinstance(namedtuple_obj, (min, cummin, max, cummax)):
namedtuple_obj = {
"values": namedtuple_obj.values,
"indices": namedtuple_obj.indices,
}
for key, value in namedtuple_obj.items():
if is_namedtuple(value):
namedtuple_obj[key] = namedtuple_to_dict(value)
Expand Down Expand Up @@ -10156,6 +10393,7 @@ def split_keys(
the arguments provided.
Args:
key_sets (sequence of Dict[in_key, out_key] or list of keys): the various splits.
inplace (bool, optional): if ``True``, the keys are removed from ``self``
in-place. Defaults to ``False``.
strict (bool, optional): if ``True``, an exception is raised when a key
Expand Down Expand Up @@ -10195,10 +10433,12 @@ def split_keys(
keys_to_del = set()
for key_set in key_sets:
outs.append(self.empty(recurse=reproduce_struct))
if not isinstance(key_set, dict):
key_set = {key: key for key in key_set}
for key in key_set:
val = last_out.pop(key, default)
if val is not None:
outs[-1].set(key, val)
outs[-1].set(key_set[key], val)
if inplace:
keys_to_del.add(key)
if inplace:
Expand Down
39 changes: 39 additions & 0 deletions tensordict/return_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from tensordict.tensorclass import tensorclass
from tensordict.tensordict import TensorDict


@tensorclass
class min:
"""A `min` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.min` operations."""

vals: TensorDict
indices: TensorDict


@tensorclass
class max:
"""A `max` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.max` operations."""

vals: TensorDict
indices: TensorDict


@tensorclass
class cummin:
"""A `cummin` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummin` operations."""

vals: TensorDict
indices: TensorDict


@tensorclass
class cummax:
"""A `cummax` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummax` operations."""

vals: TensorDict
indices: TensorDict
Loading

0 comments on commit 9ca4207

Please sign in to comment.