diff --git a/tensordict/base.py b/tensordict/base.py index 8c3a7a3ba..ef560b7fc 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1647,6 +1647,9 @@ def to_module( ): """Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively. + ``to_module`` can also be used a context manager to temporarily populate a module with a collection of + parameters/buffers (see example below). + Args: module (nn.Module): a module to write the parameters into. @@ -1670,9 +1673,27 @@ def to_module( ... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4), ... num_layers=1) >>> params = TensorDict.from_module(module) - >>> params.zero_() + >>> params.data.zero_() >>> params.to_module(module) >>> assert (module.layers[0].linear1.weight == 0).all() + + Using a tensordict as a context manager can be useful to make functional calls: + Examples: + >>> from tensordict import from_module + >>> module = nn.TransformerDecoder( + ... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4), + ... num_layers=1) + >>> params = TensorDict.from_module(module) + >>> params = params.data * 0 # Use TensorDictParams to remake these tensors regular nn.Parameter instances + >>> with params.to_module(module): + ... # Call the module with zeroed params + ... y = module(*inputs) + >>> # The module is repopulated with its original params + >>> assert (TensorDict.from_module(module) != 0).any() + + Returns: + A tensordict containing the values from the module if ``return_swap`` is ``True``, ``None`` otherwise. + """ if memo is not None: raise RuntimeError("memo cannot be passed to the public to_module anymore.")