Skip to content

Commit

Permalink
[Feature] TensorDict.softmax
Browse files Browse the repository at this point in the history
ghstack-source-id: a88bebc23e6aaa02ec297db72dbda68ec9628ce7
Pull Request resolved: #1163
  • Loading branch information
vmoens committed Jan 7, 2025
1 parent 748db34 commit d8dceaf
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9465,7 +9465,7 @@ def logsumexp(self, dim=None, keepdim=False, *, out=None): # noqa: D417
Otherwise, ``dim`` is squeezed (see :func:`~torch.squeeze`), resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).

Args:
dim (int or tuple of ints): the dimension or dimensions to reduce. If ``None``, all batch dimensions of the
dim (int or tuple of ints, optional): the dimension or dimensions to reduce. If ``None``, all batch dimensions of the
tensordict are reduced.
keepdim (bool): whether the output tensordict has dim retained or not.

Expand Down Expand Up @@ -9509,6 +9509,31 @@ def logsumexp(self, dim=None, keepdim=False, *, out=None): # noqa: D417
batch_size=batch_size,
)

def softmax(self, dim: int, dtype: torch.dtype | None = None): # noqa: D417
"""Apply a softmax function to the tensordict elements.

Args:
dim (int or tuple of ints): A tensordict dimension along which softmax will be computed.
dtype (torch.dtype, optional): the desired data type of returned tensor.
If specified, the input tensor is cast to dtype before the operation is performed.
This is useful for preventing data type overflows.

"""
if isinstance(dim, int):
if dim < 0:
new_dim = self.ndim + dim
else:
new_dim = dim
else:
raise ValueError(f"Expected dim of type int, got {type(dim)}.")
if (new_dim < 0) or (new_dim >= self.ndim):
raise ValueError(
f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
)
return self._fast_apply(
lambda x: torch.softmax(x, dim=new_dim, dtype=dtype),
)

def log10(self) -> T:
"""Computes the :meth:`~torch.log10` value of each element of the TensorDict."""
keys, vals = self._items_list(True, True)
Expand Down
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def __subclasscheck__(self, subclass):
"sin_",
"sinh",
"sinh_",
"softmax",
"split",
"sqrt",
"sqrt_",
Expand Down
12 changes: 12 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6303,6 +6303,18 @@ def test_shape(self, td_name, device):
td = getattr(self, td_name)(device)
assert td.shape == td.batch_size

@pytest.mark.parametrize("dim", [0, -1, 3])
def test_softmax(self, td_name, device, dim):
td = getattr(self, td_name)(device)
if td_name in ("sub_td", "sub_td2"):
return
with td.unlock_():
td.apply(lambda x: x.float(), out=td)
tds = td.softmax(dim=dim)
assert tds.shape == td.shape
tds._check_batch_size()
tds._check_device()

def test_sorted_keys(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit d8dceaf

Please sign in to comment.