Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 21, 2025
1 parent 183b288 commit 1abb8f8
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,13 @@ def _generate_next_value_(name, start, count, last_values):
return name.lower()


_composite_lp_aggregate = _ContextManager()
_composite_lp_aggregate = _ContextManager(
default=(
strtobool(os.getenv("COMPOSITE_LP_AGGREGATE"))
if os.getenv("COMPOSITE_LP_AGGREGATE") is not None
else None
)
)


def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
Expand All @@ -467,9 +473,9 @@ def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
if not nowarn:
warnings.warn(
"Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will "
"currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will "
"currently return ``True``. However, from v0.9, this behavior will change and ``composite_lp_aggregate`` will "
"return ``False``. Please change your code accordingly by specifying the aggregation strategy via "
"`tensordict.nn.set_composite_lp_aggregate`.",
"`tensordict.nn.set_composite_lp_aggregate` or via the `COMPOSITE_LP_AGGREGATE` environment variable.",
category=DeprecationWarning,
)
return True
Expand All @@ -483,6 +489,8 @@ class set_composite_lp_aggregate(_DecoratorContextManager):
will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated in favor of
non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies).
The value of composite_lp_aggregate can also be controlled through the `COMPOSITE_LP_AGGREGATE` environment variable.
Example:
>>> _ = torch.manual_seed(0)
>>> from tensordict import TensorDict
Expand Down

0 comments on commit 1abb8f8

Please sign in to comment.