diff --git a/tensordict/base.py b/tensordict/base.py
index 0fe2eb737..4efb362d8 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -9465,7 +9465,7 @@ def logsumexp(self, dim=None, keepdim=False, *, out=None):  # noqa: D417
         Otherwise, ``dim`` is squeezed (see :func:`~torch.squeeze`), resulting in the output tensor
         having 1 (or len(dim)) fewer dimension(s).

         Args:
-            dim (int or tuple of ints): the dimension or dimensions to reduce. If ``None``, all batch dimensions of the
+            dim (int or tuple of ints, optional): the dimension or dimensions to reduce. If ``None``, all batch dimensions of the
                 tensordict are reduced.
             keepdim (bool): whether the output tensordict has dim retained or not.

@@ -9509,6 +9509,31 @@ def logsumexp(self, dim=None, keepdim=False, *, out=None):  # noqa: D417
             batch_size=batch_size,
         )

+    def softmax(self, dim: int, dtype: torch.dtype | None = None):  # noqa: D417
+        """Apply a softmax function to the tensordict elements.
+
+        Args:
+            dim (int): a tensordict batch dimension along which softmax will be computed.
+            dtype (torch.dtype, optional): the desired data type of the returned tensors.
+                If specified, the input tensors are cast to ``dtype`` before the operation is performed.
+                This is useful for preventing data type overflows.
+
+        """
+        if isinstance(dim, int):
+            if dim < 0:
+                new_dim = self.ndim + dim
+            else:
+                new_dim = dim
+        else:
+            raise ValueError(f"Expected dim of type int, got {type(dim)}.")
+        if (new_dim < 0) or (new_dim >= self.ndim):
+            raise ValueError(
+                f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
+            )
+        return self._fast_apply(
+            lambda x: torch.softmax(x, dim=new_dim, dtype=dtype),
+        )
+
     def log10(self) -> T:
         """Computes the :meth:`~torch.log10` value of each element of the TensorDict."""
         keys, vals = self._items_list(True, True)
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index e1ba45855..141e47aaf 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -354,6 +354,7 @@ def __subclasscheck__(self, subclass):
     "sin_",
     "sinh",
     "sinh_",
+    "softmax",
     "split",
     "sqrt",
     "sqrt_",
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 54d7541ba..22582441b 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6303,6 +6303,18 @@ def test_shape(self, td_name, device):
         td = getattr(self, td_name)(device)
         assert td.shape == td.batch_size

+    @pytest.mark.parametrize("dim", [0, -1, 3])
+    def test_softmax(self, td_name, device, dim):
+        td = getattr(self, td_name)(device)
+        if td_name in ("sub_td", "sub_td2"):
+            return
+        with td.unlock_():
+            td.apply(lambda x: x.float(), out=td)
+        tds = td.softmax(dim=dim)
+        assert tds.shape == td.shape
+        tds._check_batch_size()
+        tds._check_device()
+
    def test_sorted_keys(self, td_name, device):
        torch.manual_seed(1)
        td = getattr(self, td_name)(device)
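
A minimal usage sketch of the new method, assuming a TensorDict with two batch dimensions; the keys "logits"/"scores" and their shapes are made up for illustration. The softmax is computed along a batch dimension, so it applies to every leaf tensor of the tensordict.

import torch
from tensordict import TensorDict

# Hypothetical TensorDict: every leaf shares the leading batch dims (4, 3).
td = TensorDict(
    {"logits": torch.randn(4, 3, 10), "scores": torch.randn(4, 3)},
    batch_size=[4, 3],
)

out = td.softmax(dim=1)      # softmax over the second batch dimension
out["logits"].shape          # torch.Size([4, 3, 10]); shapes are preserved
out["scores"].sum(dim=1)     # each entry is ~1.0

# dtype upcasts before the reduction, e.g. to keep half-precision inputs stable
half = td.apply(lambda x: x.half())
stable = half.softmax(dim=0, dtype=torch.float32)

# Per the check in the new method, a dim outside the batch dims raises:
# td.softmax(dim=2)  ->  ValueError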