Skip to content

Commit

Permalink
[Doc] Better docstring for to_module
Browse files Browse the repository at this point in the history
ghstack-source-id: 16cedee8c0d38da6f377a262d5d7478a66fce07f
Pull Request resolved: #1081
  • Loading branch information
vmoens committed Nov 7, 2024
1 parent 31c7330 commit 9607cf0
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,9 @@ def to_module(
):
"""Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively.
``to_module`` can also be used a context manager to temporarily populate a module with a collection of
parameters/buffers (see example below).
Args:
module (nn.Module): a module to write the parameters into.
Expand All @@ -1670,9 +1673,27 @@ def to_module(
... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
... num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params.zero_()
>>> params.data.zero_()
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
Using a tensordict as a context manager can be useful to make functional calls:
Examples:
>>> from tensordict import from_module
>>> module = nn.TransformerDecoder(
... decoder_layer=nn.TransformerDecoderLayer(nhead=4, d_model=4),
... num_layers=1)
>>> params = TensorDict.from_module(module)
>>> params = params.data * 0 # Use TensorDictParams to remake these tensors regular nn.Parameter instances
>>> with params.to_module(module):
... # Call the module with zeroed params
... y = module(*inputs)
>>> # The module is repopulated with its original params
>>> assert (TensorDict.from_module(module) != 0).any()
Returns:
A tensordict containing the values from the module if ``return_swap`` is ``True``, ``None`` otherwise.
"""
if memo is not None:
raise RuntimeError("memo cannot be passed to the public to_module anymore.")
Expand Down

0 comments on commit 9607cf0

Please sign in to comment.