From 212754718c306075dc98f69482627c1d42d1858d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 24 Oct 2024 11:47:22 -0700
Subject: [PATCH] [Feature] min, amin, max, amax, cummin, cummax

ghstack-source-id: 9873c08f98e84b372c6f701a3326e900454dc1d0
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1057
---
 tensordict/_lazy.py        |   3 +-
 tensordict/_td.py          |  14 +-
 tensordict/base.py         | 262 ++++++++++++++++++++++++++++++++++++-
 tensordict/return_types.py |  39 ++++++
 tensordict/tensorclass.py  |   6 +-
 test/test_tensordict.py    |  97 ++++++++++++++
 6 files changed, 414 insertions(+), 7 deletions(-)
 create mode 100644 tensordict/return_types.py

diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 88145d71e..d2daf8b23 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -1826,7 +1826,8 @@ def _apply_nest(
             arg for arg in (batch_size, device, names, constructor_kwargs)
         ):
             raise ValueError(
-                "Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True."
+                "Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True. Got args "
+                f"batch_size={batch_size}, device={device}, names={names}, constructor_kwargs={constructor_kwargs}"
             )
         if out is not None:
             if not isinstance(out, LazyStackedTensorDict):
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 4387839b5..f5960fbb6 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -956,6 +956,9 @@ def _cast_reduction(
         keepdim=NO_DEFAULT,
         tuple_ok=True,
         further_reduce: bool,
+        values_only: bool = True,
+        call_on_nested: bool = True,
+        batch_size=None,
         **kwargs,
     ):
         if further_reduce:
@@ -1015,9 +1018,16 @@ def reduction(val):
             result = getattr(val, reduction_name)(
                 **kwargs,
             )
+            if isinstance(result, tuple):
+                if values_only:
+                    result = result.values
+                else:
+                    return TensorDict.from_namedtuple(result)
             return result

-        if dim not in (None, NO_DEFAULT):
+        if batch_size is not None:
+            pass
+        elif dim is not None and dim is not NO_DEFAULT:
             if not keepdim:
                 if isinstance(dim, tuple):
                     batch_size = [
@@ -1043,7 +1053,7 @@ def reduction(val):

         return self._fast_apply(
             reduction,
-            call_on_nested=True,
+            call_on_nested=call_on_nested,
             batch_size=torch.Size(batch_size),
             device=self.device,
             names=names,
diff --git a/tensordict/base.py b/tensordict/base.py
index 358cae1b1..73bf8b896 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -120,7 +120,7 @@ class _NoDefault(enum.IntEnum):

 NO_DEFAULT = _NoDefault.ZERO
-
+assert not NO_DEFAULT
 T = TypeVar("T", bound="TensorDictBase")

@@ -575,6 +575,248 @@ def isreal(self) -> T:
             propagate_lock=True,
         )

+    def amin(
+        self,
+        dim: int | NO_DEFAULT = NO_DEFAULT,
+        keepdim: bool = False,
+        *,
+        reduce: bool | None = None,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the minimum values of all elements in the input tensordict.
+
+        Same as :meth:`~.min` with ``return_indices=False``.
+        """
+        return self._cast_reduction(
+            reduction_name="amin",
+            dim=dim,
+            keepdim=keepdim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=True,
+            call_on_nested=False,
+        )
+
+    def min(
+        self,
+        dim: int | NO_DEFAULT = NO_DEFAULT,
+        keepdim: bool = False,
+        *,
+        reduce: bool | None = None,
+        return_indices: bool = True,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the minimum values of all elements in the input tensordict.
+
+        Args:
+            dim (int, optional): if ``None``, returns a dimensionless
+                tensordict containing the min value of all leaves (if this can be computed).
+                If an integer, ``min`` is called over the specified dimension,
+                provided that this dimension is compatible with the tensordict
+                shape.
+            keepdim (bool): whether the output tensor has dim retained or not.
+
+        Keyword Args:
+            reduce (bool, optional): if ``True``, the reduction will occur across all TensorDict values
+                and a single reduced tensor will be returned.
+                Defaults to ``False``.
+            return_indices (bool, optional): :func:`~torch.min` returns a named tuple with values and indices
+                when the ``dim`` argument is passed. The ``TensorDict`` equivalent is to return a tensorclass
+                with ``"vals"`` and ``"indices"`` entries of identical structure. Defaults to ``True``.
+
+        """
+        result = self._cast_reduction(
+            reduction_name="min",
+            dim=dim,
+            keepdim=keepdim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=not return_indices,
+            call_on_nested=False,
+        )
+        if dim is not NO_DEFAULT and return_indices:
+            # Split the tensordict
+            from .return_types import min
+
+            values_dict = {}
+            indices_dict = {}
+            for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+                if key[-1] == "values":
+                    values_dict[key] = key[:-1]
+                else:
+                    indices_dict[key] = key[:-1]
+            return min(
+                *result.split_keys(values_dict, indices_dict),
+                batch_size=result.batch_size,
+            )
+        return result
+
+    def amax(
+        self,
+        dim: int | NO_DEFAULT = NO_DEFAULT,
+        keepdim: bool = False,
+        *,
+        reduce: bool | None = None,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the maximum values of all elements in the input tensordict.
+
+        Same as :meth:`~.max` with ``return_indices=False``.
+        """
+        return self._cast_reduction(
+            reduction_name="amax",
+            dim=dim,
+            keepdim=keepdim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=True,
+            call_on_nested=False,
+        )
+
+    def max(
+        self,
+        dim: int | NO_DEFAULT = NO_DEFAULT,
+        keepdim: bool = False,
+        *,
+        reduce: bool | None = None,
+        return_indices: bool = True,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the maximum values of all elements in the input tensordict.
+
+        Args:
+            dim (int, optional): if ``None``, returns a dimensionless
+                tensordict containing the max value of all leaves (if this can be computed).
+                If an integer, ``max`` is called over the specified dimension,
+                provided that this dimension is compatible with the tensordict
+                shape.
+            keepdim (bool): whether the output tensor has dim retained or not.
+
+        Keyword Args:
+            reduce (bool, optional): if ``True``, the reduction will occur across all TensorDict values
+                and a single reduced tensor will be returned.
+                Defaults to ``False``.
+            return_indices (bool, optional): :func:`~torch.max` returns a named tuple with values and indices
+                when the ``dim`` argument is passed. The ``TensorDict`` equivalent is to return a tensorclass
+                with ``"vals"`` and ``"indices"`` entries of identical structure. Defaults to ``True``.
+
+        """
+        result = self._cast_reduction(
+            reduction_name="max",
+            dim=dim,
+            keepdim=keepdim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=not return_indices,
+            call_on_nested=False,
+        )
+        if dim is not NO_DEFAULT and return_indices:
+            # Split the tensordict
+            from .return_types import max
+
+            values_dict = {}
+            indices_dict = {}
+            for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+                if key[-1] == "values":
+                    values_dict[key] = key[:-1]
+                else:
+                    indices_dict[key] = key[:-1]
+            return max(
+                *result.split_keys(values_dict, indices_dict),
+                batch_size=result.batch_size,
+            )
+        return result
+
+    def cummin(
+        self,
+        dim: int,
+        *,
+        reduce: bool | None = None,
+        return_indices: bool = True,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the cumulative minimum values of all elements in the input tensordict.
+
+        Args:
+            dim (int): integer representing the dimension along which to perform the cummin operation.
+
+        Keyword Args:
+            reduce (bool, optional): if ``True``, the reduction will occur across all TensorDict values
+                and a single reduced tensor will be returned.
+                Defaults to ``False``.
+            return_indices (bool, optional): :func:`~torch.cummin` returns a named tuple with values and indices
+                when the ``dim`` argument is passed. The ``TensorDict`` equivalent is to return a tensorclass
+                with ``"vals"`` and ``"indices"`` entries of identical structure. Defaults to ``True``.
+
+        """
+        result = self._cast_reduction(
+            reduction_name="cummin",
+            dim=dim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=not return_indices,
+            call_on_nested=False,
+            batch_size=self.batch_size,
+        )
+        if dim is not NO_DEFAULT and return_indices:
+            # Split the tensordict
+            from .return_types import cummin
+
+            values_dict = {}
+            indices_dict = {}
+            for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+                if key[-1] == "values":
+                    values_dict[key] = key[:-1]
+                else:
+                    indices_dict[key] = key[:-1]
+            return cummin(
+                *result.split_keys(values_dict, indices_dict),
+                batch_size=result.batch_size,
+            )
+        return result
+
+    def cummax(
+        self,
+        dim: int,
+        *,
+        reduce: bool | None = None,
+        return_indices: bool = True,
+    ) -> TensorDictBase | torch.Tensor:  # noqa: D417
+        """Returns the cumulative maximum values of all elements in the input tensordict.
+
+        Args:
+            dim (int): integer representing the dimension along which to perform the cummax operation.
+
+        Keyword Args:
+            reduce (bool, optional): if ``True``, the reduction will occur across all TensorDict values
+                and a single reduced tensor will be returned.
+                Defaults to ``False``.
+            return_indices (bool, optional): :func:`~torch.cummax` returns a named tuple with values and indices
+                when the ``dim`` argument is passed. The ``TensorDict`` equivalent is to return a tensorclass
+                with ``"vals"`` and ``"indices"`` entries of identical structure. Defaults to ``True``.
+
+        """
+        result = self._cast_reduction(
+            reduction_name="cummax",
+            dim=dim,
+            further_reduce=reduce,
+            tuple_ok=False,
+            values_only=not return_indices,
+            call_on_nested=False,
+            batch_size=self.batch_size,
+        )
+        if dim is not NO_DEFAULT and return_indices:
+            # Split the tensordict
+            from .return_types import cummax
+
+            values_dict = {}
+            indices_dict = {}
+            for key in result.keys(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+                if key[-1] == "values":
+                    values_dict[key] = key[:-1]
+                else:
+                    indices_dict[key] = key[:-1]
+            return cummax(
+                *result.split_keys(values_dict, indices_dict),
+                batch_size=result.batch_size,
+            )
+        return result
+
     def mean(
         self,
         dim: int | Tuple[int] = NO_DEFAULT,
@@ -4840,7 +5082,9 @@ def _filter(x):

     def filter_empty_(self):
         """Filters out all empty tensordicts in-place."""
-        for key, val in list(self.items(True, is_leaf=_NESTED_TENSORS_AS_LISTS)):
+        for key, val in reversed(
+            list(self.items(True, is_leaf=_NESTED_TENSORS_AS_LISTS, sort=True))
+        ):
             if _is_tensor_collection(type(val)) and val.is_empty():
                 del self[key]
         return self
@@ -9581,6 +9825,15 @@ def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
         def namedtuple_to_dict(namedtuple_obj):
             if is_namedtuple(namedtuple_obj):
                 namedtuple_obj = namedtuple_obj._asdict()
+
+            else:
+                from torch.return_types import cummax, cummin, max, min
+
+                if isinstance(namedtuple_obj, (min, cummin, max, cummax)):
+                    namedtuple_obj = {
+                        "values": namedtuple_obj.values,
+                        "indices": namedtuple_obj.indices,
+                    }
             for key, value in namedtuple_obj.items():
                 if is_namedtuple(value):
                     namedtuple_obj[key] = namedtuple_to_dict(value)
@@ -10156,6 +10409,7 @@ def split_keys(
         the arguments provided.

         Args:
+            key_sets (sequence of Dict[in_key, out_key] or list of keys): the key sets defining each split; a dict maps the key in ``self`` to the key of the output tensordict.
             inplace (bool, optional): if ``True``, the keys are removed from ``self``
                 in-place. Defaults to ``False``.
             strict (bool, optional): if ``True``, an exception is raised when a key
@@ -10195,10 +10449,12 @@ def split_keys(
         keys_to_del = set()
         for key_set in key_sets:
             outs.append(self.empty(recurse=reproduce_struct))
+            if not isinstance(key_set, dict):
+                key_set = {key: key for key in key_set}
             for key in key_set:
                 val = last_out.pop(key, default)
                 if val is not None:
-                    outs[-1].set(key, val)
+                    outs[-1].set(key_set[key], val)
                 if inplace:
                     keys_to_del.add(key)
             if inplace:
diff --git a/tensordict/return_types.py b/tensordict/return_types.py
new file mode 100644
index 000000000..0c6668305
--- /dev/null
+++ b/tensordict/return_types.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from tensordict.tensorclass import tensorclass
+from tensordict.tensordict import TensorDict
+
+
+@tensorclass
+class min:
+    """A `min` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.min` operations."""
+
+    vals: TensorDict
+    indices: TensorDict
+
+
+@tensorclass
+class max:
+    """A `max` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.max` operations."""
+
+    vals: TensorDict
+    indices: TensorDict
+
+
+@tensorclass
+class cummin:
+    """A `cummin` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummin` operations."""
+
+    vals: TensorDict
+    indices: TensorDict
+
+
+@tensorclass
+class cummax:
+    """A `cummax` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummax` operations."""
+
+    vals: TensorDict
+    indices: TensorDict
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 8906eefd4..69dd3fffe 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -210,6 +210,8 @@ def __subclasscheck__(self, subclass):
     "cosh_",
     "cpu",
     "cuda",
+    "cummax",
+    "cummin",
     "densify",
     "div",
     "div_",
@@ -253,9 +255,11 @@ def __subclasscheck__(self, subclass):
     "map_iter",
     "masked_fill",
     "masked_fill_",
+    "max",
     "maximum",
     "maximum_",
     "mean",
+    "min",
     "minimum",
     "minimum_",
     "mul",
@@ -279,9 +283,9 @@ def __subclasscheck__(self, subclass):
     "reciprocal",
     "reciprocal_",
     "refine_names",
-    "requires_grad_",
     "rename_",  # TODO: must be specialized
     "replace",
+    "requires_grad_",
     "reshape",
     "round",
     "round_",
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 0f1f65b5d..3c9636f32 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -4749,6 +4749,103 @@ def test_memmap_threads(self, td_name, device, use_dir, tmpdir, num_threads):
         )
         assert_allclose_td(td.cpu().detach(), tdfuture.result())

+    @pytest.mark.parametrize(
+        "dim, keepdim, return_indices",
+        [
+            [None, False, False],
+            [0, False, False],
+            [0, True, False],
+            [0, False, True],
+            [0, True, True],
+            [1, False, False],
+            [1, True, False],
+            [1, False, True],
+            [1, True, True],
+            [-1, False, False],
+            [-1, True, False],
+            [-1, False, True],
+            [-1, True, True],
+        ],
+    )
+    def test_min_max_cummin_cummax(self, td_name, device, dim, keepdim, return_indices):
+        import tensordict.return_types as return_types
+
+        td = getattr(self, td_name)(device)
+        # min
+        if dim is not None:
+            kwargs = {"dim": dim, "keepdim": keepdim, "return_indices": return_indices}
+        else:
+            kwargs = {}
+        r = td.min(**kwargs)
+        if not return_indices and dim is not None:
+            assert_allclose_td(r, td.amin(dim=dim, keepdim=keepdim))
+        if return_indices:
+            assert is_tensorclass(r)
+            assert isinstance(r, return_types.min)
+            assert not r.vals.is_empty()
+            assert not r.indices.is_empty()
+        else:
+            assert not is_tensorclass(r)
+        if dim is None:
+            assert r.batch_size == ()
+        elif keepdim:
+            s = list(td.batch_size)
+            s[dim] = 1
+            assert r.batch_size == tuple(s)
+        else:
+            s = list(td.batch_size)
+            s.pop(dim)
+            assert r.batch_size == tuple(s)
+
+        r = td.max(**kwargs)
+        if not return_indices and dim is not None:
+            assert_allclose_td(r, td.amax(dim=dim, keepdim=keepdim))
+        if return_indices:
+            assert is_tensorclass(r)
+            assert isinstance(r, return_types.max)
+            assert not r.vals.is_empty()
+            assert not r.indices.is_empty()
+        else:
+            assert not is_tensorclass(r)
+        if dim is None:
+            assert r.batch_size == ()
+        elif keepdim:
+            s = list(td.batch_size)
+            s[dim] = 1
+            assert r.batch_size == tuple(s)
+        else:
+            s = list(td.batch_size)
+            s.pop(dim)
+            assert r.batch_size == tuple(s)
+        if dim is None:
+            return
+        kwargs.pop("keepdim")
+        r = td.cummin(**kwargs)
+        if return_indices:
+            assert is_tensorclass(r)
+            assert isinstance(r, return_types.cummin)
+            assert not r.vals.is_empty()
+            assert not r.indices.is_empty()
+        else:
+            assert not is_tensorclass(r)
+        if dim is None:
+            assert r.batch_size == ()
+        else:
+            assert r.batch_size == td.batch_size
+
+        r = td.cummax(**kwargs)
+        if return_indices:
+            assert is_tensorclass(r)
+            assert isinstance(r, return_types.cummax)
+            assert not r.vals.is_empty()
+            assert not r.indices.is_empty()
+        else:
+            assert not is_tensorclass(r)
+        if dim is None:
+            assert r.batch_size == ()
+        else:
+            assert r.batch_size == td.batch_size
+
     @pytest.mark.parametrize("inplace", [False, True])
     def test_named_apply(self, td_name, device, inplace):
         td = getattr(self, td_name)(device)
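
A minimal usage sketch of the reduction API this patch adds, based on the docstrings and tests above; the entry names "a" and ("b", "c") and the shapes are made up for illustration:

import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4, 5)}},
    batch_size=[3, 4],
)

# With a dim and return_indices=True (the default), a return_types.min
# tensorclass is returned; its vals and indices entries mirror td.
out = td.min(dim=1)
assert out.vals["a"].shape == torch.Size([3])
assert out.indices["b", "c"].shape == torch.Size([3, 5])

# With return_indices=False, only the values come back, matching td.amin.
vals = td.min(dim=1, return_indices=False)

# reduce=True collapses the reduction across all leaves into one tensor.
global_min = td.min(reduce=True)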
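A sketch of the cumulative variants, which keep the batch size instead of reducing it (same hypothetical td as above):

import tensordict.return_types as return_types

# cummax/cummin require a dim and preserve the original batch size.
cm = td.cummax(dim=0)
assert isinstance(cm, return_types.cummax)
assert cm.vals.batch_size == td.batch_size
assert cm.indices.batch_size == td.batch_size

# Values only: a plain tensordict with the same batch size as td.
cvals = td.cummin(dim=0, return_indices=False)
assert cvals.batch_size == td.batch_size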
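Finally, a sketch of the extended split_keys, whose key sets may now be dicts mapping source keys to destination keys; this assumes, per the docstring, that N key sets yield N + 1 outputs, with the remainder last:

import torch
from tensordict import TensorDict

td2 = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])

# A dict key set renames on the way out: "a" lands in the first output as "alpha".
first, second, remainder = td2.split_keys({"a": "alpha"}, ["b"])
assert "alpha" in first.keys()
assert "b" in second.keys()
assert remainder.is_empty()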