Commit

Update
[ghstack-poisoned]
vmoens committed Nov 29, 2024
1 parent 0e18600 commit 330fb5e
Showing 3 changed files with 39 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/nn.rst
@@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample
TensorDictSequential
TensorDictModuleWrapper
CudaGraphModule
+WrapModule

Ensembles
---------
1 change: 1 addition & 0 deletions tensordict/nn/__init__.py
@@ -9,6 +9,7 @@
TensorDictModule,
TensorDictModuleBase,
TensorDictModuleWrapper,
+WrapModule,
)
from tensordict.nn.distributions import (
AddStateIndependentNormalScale,
41 changes: 37 additions & 4 deletions tensordict/nn/common.py
@@ -1278,12 +1278,45 @@ def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase:


class WrapModule(TensorDictModuleBase):
"""A wrapper around any callable that processes TensorDict instances.
This wrapper is useful when building :class:`~tensordict.nn.TensorDictSequential` stacks and when a transform
requires the entire TensorDict instance to be visible.
Args:
func (Callable[[TensorDictBase], TensorDictBase]): A callable function that takes in a TensorDictBase instance
and returns a transformed TensorDictBase instance.
Keyword Args:
inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict
will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``.
Examples:
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
... Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
... WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = Seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)
"""

in_keys = []
out_keys = []

-    def __init__(self, func):
-        self.func = func
+    def __init__(
+        self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False
+    ) -> None:
+        super().__init__()
+        self.func = func
+        self.inplace = inplace

-    def forward(self, data):
-        return self.func(data)
+    def forward(self, data: TensorDictBase) -> TensorDictBase:
+        result = self.func(data)
+        if self.inplace and result is not data:
+            return data.update(result)
+        return result
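
For reference, a minimal usage sketch of the new ``inplace`` flag (not part of the commit; the names ``add_y``, ``x`` and ``y`` are illustrative only). It relies on the ``WrapModule`` export added in ``tensordict/nn/__init__.py`` above. When the wrapped callable returns a TensorDict other than its input, ``WrapModule`` writes the result back into the input and returns that same instance:

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import WrapModule
>>> td = TensorDict(x=torch.ones(3), batch_size=[])
>>> add_y = WrapModule(
...     lambda td: TensorDict(y=td["x"] * 2, batch_size=td.batch_size),
...     inplace=True,
... )
>>> out = add_y(td)
>>> assert out is td                  # the input instance is returned...
>>> assert (td["y"] == 2).all()       # ...updated with the callable's output
>>> assert "x" in td.keys()           # pre-existing entries are preserved

With the default ``inplace=False``, the same module would simply return the new TensorDict produced by the lambda, leaving ``td`` untouched.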
