Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 29, 2024
1 parent e871b7d commit 7016a89
Show file tree
Hide file tree
Showing 4 changed files with 686 additions and 86 deletions.
26 changes: 7 additions & 19 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,6 @@ def from_distributions(
self.inplace = inplace
return self

@property
def aggregate_probabilities(self):
aggregate_probabilities = self._aggregate_probabilities
if aggregate_probabilities is None:
warnings.warn(
"The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. "
"Please pass this value explicitly to avoid this warning.",
FutureWarning,
)
aggregate_probabilities = self._aggregate_probabilities = False
return aggregate_probabilities

@aggregate_probabilities.setter
def aggregate_probabilities(self, value):
self._aggregate_probabilities = value

def sample(self, shape=None) -> TensorDictBase:
if shape is None:
shape = torch.Size([])
Expand Down Expand Up @@ -337,7 +321,7 @@ def log_prob(
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
from the class.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
Has no effect if ``aggregate_probabilities`` is set to ``True``.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand All @@ -356,6 +340,8 @@ def log_prob(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.log_prob_composite(
sample, include_sum=include_sum, inplace=inplace
Expand All @@ -382,7 +368,7 @@ def log_prob_composite(
Keyword Args:
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand Down Expand Up @@ -451,7 +437,7 @@ def entropy(
setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict
with individual entropies. Defaults to ``False`` if not set in the class.
include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
Defaults to `self.inplace`, which is set through the class constructor. Has no effect if
Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if
`aggregate_probabilities` is set to `True`.
.. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
Expand All @@ -466,6 +452,8 @@ def entropy(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.entropy_composite(samples_mc, include_sum=include_sum)
se = 0.0
Expand Down
Loading

0 comments on commit 7016a89

Please sign in to comment.