Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 26, 2024
2 parents 4729e2f + 3dbe083 commit f02115c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
3 changes: 2 additions & 1 deletion tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ def log_prob_composite(
"The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
category=DeprecationWarning,
)
slp = 0.0
if include_sum:
slp = 0.0
d = {}
for name, dist in self.dists.items():
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
Expand Down
37 changes: 30 additions & 7 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import pytest
import torch

from tensordict import NonTensorData, NonTensorStack, tensorclass, TensorDict
from tensordict._C import unravel_key_list
from tensordict.nn import (
Expand Down Expand Up @@ -2277,7 +2278,9 @@ def test_log_prob(self):
assert isinstance(lp, torch.Tensor)
assert lp.requires_grad

def test_log_prob_composite(self):
@pytest.mark.parametrize("inplace", [None, True, False])
@pytest.mark.parametrize("include_sum", [None, True, False])
def test_log_prob_composite(self, inplace, include_sum):
params = TensorDict(
{
"cont": {
Expand All @@ -2296,12 +2299,25 @@ def test_log_prob_composite(self):
},
extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
aggregate_probabilities=False,
inplace=inplace,
include_sum=include_sum,
)
if include_sum is None:
include_sum = True
if inplace is None:
inplace = True
sample = dist.rsample((4,))
sample = dist.log_prob_composite(sample, include_sum=True)
assert sample.get("cont_log_prob").requires_grad
assert sample.get(("nested", "disc_log_prob")).requires_grad
assert "sample_log_prob" in sample.keys()
sample_lp = dist.log_prob_composite(sample)
assert sample_lp.get("cont_log_prob").requires_grad
assert sample_lp.get(("nested", "disc_log_prob")).requires_grad
if inplace:
assert sample_lp is sample
else:
assert sample_lp is not sample
if include_sum:
assert "sample_log_prob" in sample_lp.keys()
else:
assert "sample_log_prob" not in sample_lp.keys()

def test_entropy(self):
params = TensorDict(
Expand All @@ -2327,7 +2343,8 @@ def test_entropy(self):
assert isinstance(ent, torch.Tensor)
assert ent.requires_grad

def test_entropy_composite(self):
@pytest.mark.parametrize("include_sum", [None, True, False])
def test_entropy_composite(self, include_sum):
params = TensorDict(
{
"cont": {
Expand All @@ -2345,12 +2362,18 @@ def test_entropy_composite(self):
("nested", "disc"): distributions.Categorical,
},
aggregate_probabilities=False,
include_sum=include_sum,
)
if include_sum is None:
include_sum = True
sample = dist.entropy()
assert sample.shape == params.shape == dist._batch_shape
assert sample.get("cont_entropy").requires_grad
assert sample.get(("nested", "disc_entropy")).requires_grad
assert "entropy" in sample.keys()
if include_sum:
assert "entropy" in sample.keys()
else:
assert "entropy" not in sample.keys()

def test_cdf(self):
params = TensorDict(
Expand Down

0 comments on commit f02115c

Please sign in to comment.