Skip to content

Commit

Permalink
[Feature] TensorDict.logsumexp
Browse files Browse the repository at this point in the history
ghstack-source-id: 84148ad9c701029db6d02dfb84ddb0a9b26c9ab7
Pull Request resolved: #1162
  • Loading branch information
vmoens committed Jan 7, 2025
1 parent ecb692e commit 748db34
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 3 deletions.
53 changes: 52 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6378,7 +6378,7 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
# We raise an exception AND a warning because we want the user to know that this exception will
# not be raised in the future
warnings.warn(
"The entry you have queried with `get` is not present in the tensordict. "
f"The entry ({key}) you have queried with `get` is not present in the tensordict. "
"Currently, this raises an exception. "
"To align with `dict.get`, this behaviour will be changed in v0.7 and a `None` value will "
"be returned instead (no error will be raised). "
Expand Down Expand Up @@ -9458,6 +9458,57 @@ def log_(self) -> T:
torch._foreach_log_(self._values_list(True, True))
return self

def logsumexp(self, dim=None, keepdim=False, *, out=None): # noqa: D417
"""Returns the log of summed exponentials of each row of the input tensordict in the given dimension ``dim``. The computation is numerically stabilized.

If keepdim is ``True``, the output tensor is of the same size as input except in the dimension(s) ``dim`` where it is of size ``1``.
Otherwise, ``dim`` is squeezed (see :func:`~torch.squeeze`), resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).

Args:
dim (int or tuple of ints): the dimension or dimensions to reduce. If ``None``, all batch dimensions of the
tensordict are reduced.
keepdim (bool): whether the output tensordict has dim retained or not.

Keyword Args:
out (TensorDictBase, optional): the output tensordict.

"""
if isinstance(dim, int):
if dim < 0:
new_dim = (self.ndim + dim,)
else:
new_dim = (dim,)
elif dim is not None:
new_dim = tuple(self.ndim + _dim if _dim < 0 else _dim for _dim in dim)
else:
new_dim = tuple(range(self.ndim))
if new_dim is not None and any((d < 0) or (d >= self.ndim) for d in new_dim):
raise ValueError(
f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
)
batch_size = self.batch_size
if keepdim:
batch_size = torch.Size(
[b if i not in new_dim else 1 for i, b in enumerate(batch_size)]
)
else:
batch_size = torch.Size(
[b for i, b in enumerate(batch_size) if i not in new_dim]
)
if out is not None:
result = self._fast_apply(
lambda x, y: torch.logsumexp(x, dim=new_dim, keepdim=keepdim, out=y),
out,
default=None,
batch_size=batch_size,
)
return out.update(result)

return self._fast_apply(
lambda x: torch.logsumexp(x, dim=new_dim, keepdim=keepdim),
batch_size=batch_size,
)

def log10(self) -> T:
"""Computes the :meth:`~torch.log10` value of each element of the TensorDict."""
keys, vals = self._items_list(True, True)
Expand Down
6 changes: 4 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def __subclasscheck__(self, subclass):
"log2",
"log2_",
"log_",
"logical_and" "map",
"logical_and",
"logsumexp",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
Expand Down Expand Up @@ -1856,7 +1858,7 @@ def _to_tensordict(self, *, retain_none: bool | None = None) -> TensorDict:
retain_none = True
warnings.warn(
"retain_none was not specified and a None value was encountered in the tensorclass. "
"As of now, the None will be written in the tensordict but this default behaviour will change"
"As of now, the None will be written in the tensordict but this default behaviour will change "
"in v0.8. To disable this warning, specify the value of retain_none."
)
if retain_none:
Expand Down
38 changes: 38 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4792,6 +4792,44 @@ def test_lock_write(self, td_name, device):
td_set = td
td_set.set_(key, item)

@pytest.mark.parametrize("has_out", [False, "complete", "empty"])
@pytest.mark.parametrize("keepdim", [True, False])
@pytest.mark.parametrize("dim", [1, (1,), (1, -1)])
def test_logsumexp(self, td_name, device, has_out, keepdim, dim):
td = getattr(self, td_name)(device)
if not has_out:
out = None
elif has_out == "complete":
out = (
td.to_tensordict(retain_none=False)
.detach()
.logsumexp(dim=dim, keepdim=keepdim)
)
if td.requires_grad:
td = td.detach()
else:
out = (
td.to_tensordict(retain_none=False)
.detach()
.logsumexp(dim=dim, keepdim=keepdim)
.empty()
)
if td.requires_grad:
td = td.detach()
if out is not None:
out_c = out.copy()
tdlse = td.logsumexp(dim=dim, out=out, keepdim=keepdim)
assert tdlse.batch_size != td.batch_size
if out is not None:

def check(x, y):
if y is not None:
assert x is y

assert tdlse is out
tdlse.apply(check, out_c, default=None)
tdlse._check_batch_size()

def test_masked_fill(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit 748db34

Please sign in to comment.