From 330fb5ea8bf20d7cfbdf71c512ed923756c35a7b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 29 Nov 2024 18:05:41 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- docs/source/reference/nn.rst | 1 + tensordict/nn/__init__.py | 1 + tensordict/nn/common.py | 41 ++++++++++++++++++++++++++++++++---- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index 7000ade01..cb6fb1739 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample TensorDictSequential TensorDictModuleWrapper CudaGraphModule + WrapModule Ensembles --------- diff --git a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 55590889a..e930ac75d 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -9,6 +9,7 @@ TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, + WrapModule, ) from tensordict.nn.distributions import ( AddStateIndependentNormalScale, diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 395141c0a..7a1a7a22b 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1278,12 +1278,45 @@ def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: class WrapModule(TensorDictModuleBase): + """A wrapper around any callable that processes TensorDict instances. + + This wrapper is useful when building :class:`~tensordict.nn.TensorDictSequential` stacks and when a transform + requires the entire TensorDict instance to be visible. + + Args: + func (Callable[[TensorDictBase], TensorDictBase]): A callable function that takes in a TensorDictBase instance + and returns a transformed TensorDictBase instance. + + Keyword Args: + inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict + will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``. + + Examples: + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule + >>> seq = Seq( + ... Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]), + ... WrapModule(lambda td: td.reshape(-1)), + ... ) + >>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4]) + >>> td = Seq(td) + >>> assert td.shape == (12,) + >>> assert (td["y"] == 2).all() + >>> assert td["y"].shape == (12, 5) + + """ + in_keys = [] out_keys = [] - def __init__(self, func): - self.func = func + def __init__( + self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False + ) -> None: super().__init__() + self.func = func + self.inplace = inplace - def forward(self, data): - return self.func(data) + def forward(self, data: TensorDictBase) -> TensorDictBase: + result = self.func(data) + if self.inplace and result is not data: + return data.update(result) + return result