Skip to content

Commit

Permalink
[Feature] Optional in_keys for WrapModule
Browse files Browse the repository at this point in the history
ghstack-source-id: a18dd5dff39937b027243fcebc6ef449b547e0b0
Pull Request resolved: #1145
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent eaafc18 commit 2d37d92
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,12 @@ class WrapModule(TensorDictModuleBase):
Keyword Args:
inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict
will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``.
in_keys (list of NestedKey, optional): if provided, indicates what entries are read by the module.
This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential`
about the input keys of the wrapped module. Defaults to `[]`.
out_keys (list of NestedKey, optional): if provided, indicates what entries are written by the module.
This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential`
about the output keys of the wrapped module. Defaults to `[]`.
Examples:
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
Expand All @@ -1320,11 +1326,20 @@ class WrapModule(TensorDictModuleBase):
out_keys = []

def __init__(
self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False
self,
func: Callable[[TensorDictBase], TensorDictBase],
*,
inplace: bool = False,
in_keys: List[NestedKey] | None = None,
out_keys: List[NestedKey] | None = None,
) -> None:
super().__init__()
self.func = func
self.inplace = inplace
if in_keys is not None:
self.in_keys = in_keys
if out_keys is not None:
self.out_keys = out_keys

def forward(self, data: TensorDictBase) -> TensorDictBase:
result = self.func(data)
Expand Down

0 comments on commit 2d37d92

Please sign in to comment.