From 1abb8f84b89824c55dd1fffa3139c9827363b1cc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 Jan 2025 09:56:23 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/nn/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 60920f23a..217d363a5 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -450,7 +450,13 @@ def _generate_next_value_(name, start, count, last_values): return name.lower() -_composite_lp_aggregate = _ContextManager() +_composite_lp_aggregate = _ContextManager( + default=( + strtobool(os.getenv("COMPOSITE_LP_AGGREGATE")) + if os.getenv("COMPOSITE_LP_AGGREGATE") is not None + else None + ) +) def composite_lp_aggregate(nowarn: bool = False) -> bool | None: @@ -467,9 +473,9 @@ def composite_lp_aggregate(nowarn: bool = False) -> bool | None: if not nowarn: warnings.warn( "Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will " - "currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will " + "currently return ``True``. However, from v0.9, this behavior will change and ``composite_lp_aggregate`` will " "return ``False``. Please change your code accordingly by specifying the aggregation strategy via " - "`tensordict.nn.set_composite_lp_aggregate`.", + "`tensordict.nn.set_composite_lp_aggregate` or via the `COMPOSITE_LP_AGGREGATE` environment variable.", category=DeprecationWarning, ) return True @@ -483,6 +489,8 @@ class set_composite_lp_aggregate(_DecoratorContextManager): will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated in favor of non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies). + The value of composite_lp_aggregate can also be controlled through the `COMPOSITE_LP_AGGREGATE` environment variable. + Example: >>> _ = torch.manual_seed(0) >>> from tensordict import TensorDict