Skip to content

Commit

Permalink
[BugFix] Computation of log prob in composite distribution for TensorDict without batch size (#1065)
Browse files Browse the repository at this point in the history

Co-authored-by: Pau Riba <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
(cherry picked from commit e64a4c3)
  • Loading branch information
priba authored and vmoens committed Nov 4, 2024
1 parent 01a44d2 commit 4dc764c
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 7 deletions.
10 changes: 5 additions & 5 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,17 +392,17 @@ def forward(
if _requires_sample:
out_tensors = self._dist_sample(dist, interaction_type=interaction_type())
if isinstance(out_tensors, TensorDictBase):
tensordict_out.update(out_tensors)
if self.return_log_prob:
kwargs = {}
if isinstance(dist, CompositeDistribution):
kwargs = {"aggregate_probabilities": False}
log_prob = dist.log_prob(tensordict_out, **kwargs)
if log_prob is not tensordict_out:
log_prob = dist.log_prob(out_tensors, **kwargs)
if log_prob is not out_tensors:
# Composite dists return the tensordict_out directly when aggregate_probabilities is False
tensordict_out.set(self.log_prob_key, log_prob)
out_tensors.set(self.log_prob_key, log_prob)
else:
tensordict_out.rename_key_(dist.log_prob_key, self.log_prob_key)
out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key)
tensordict_out.update(out_tensors)
else:
if isinstance(out_tensors, Tensor):
out_tensors = (out_tensors,)
Expand Down
134 changes: 132 additions & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,7 +2331,7 @@ def test_prob_module(self, interaction, return_log_prob, map_names):
},
}
},
[3],
batch_size=(3,),
)
in_keys = ["params"]
out_keys = ["cont", ("nested", "cont")]
Expand Down Expand Up @@ -2389,6 +2389,82 @@ def test_prob_module(self, interaction, return_log_prob, map_names):
+ sample_clone.get(key_logprob1).sum(-1),
)

@pytest.mark.parametrize(
    "interaction", [InteractionType.MODE, InteractionType.MEAN]
)
@pytest.mark.parametrize("map_names", [True, False])
def test_prob_module_nested(self, interaction, map_names):
    # Regression test: composite-distribution log-prob when the parameter
    # TensorDict is nested under "agents" and the ROOT TensorDict carries
    # no batch size (only the inner "agents" node has batch_size=3).
    def normal_params():
        # Fresh loc/scale leaves for one Normal head.
        return {
            "loc": torch.randn(3, 4, requires_grad=True),
            "scale": torch.rand(3, 4, requires_grad=True),
        }

    params = TensorDict(
        {
            "agents": TensorDict(
                {
                    "params": {
                        "cont": normal_params(),
                        ("nested", "cont"): normal_params(),
                    }
                },
                batch_size=3,
            ),
            "done": torch.ones(1),
        }
    )
    dist_kwargs = {
        "distribution_map": {
            "cont": distributions.Normal,
            ("nested", "cont"): distributions.Normal,
        },
        "log_prob_key": ("agents", "sample_log_prob"),
    }
    if map_names:
        # Remap sample keys under a common ("sample", "agents") prefix.
        dist_kwargs["name_map"] = {
            "cont": ("sample", "agents", "cont"),
            ("nested", "cont"): ("sample", "agents", "nested", "cont"),
        }
        out_keys = list(dist_kwargs["name_map"].values())
    else:
        out_keys = ["cont", ("nested", "cont")]
    module = ProbabilisticTensorDictModule(
        in_keys=[("agents", "params")],
        out_keys=None,
        distribution_class=CompositeDistribution,
        distribution_kwargs=dist_kwargs,
        default_interaction_type=interaction,
        return_log_prob=True,
        log_prob_key=("agents", "sample_log_prob"),
    )
    # loosely checks that the log-prob keys have been added
    assert module.out_keys[-2:] != out_keys

    sample = module(params)
    if map_names:
        key_logprob0 = ("sample", "agents", "cont_log_prob")
        key_logprob1 = ("sample", "agents", "nested", "cont_log_prob")
    else:
        key_logprob0 = "cont_log_prob"
        key_logprob1 = ("nested", "cont_log_prob")
    assert key_logprob0 in sample
    assert key_logprob1 in sample
    assert all(key in sample for key in module.out_keys)

    # The aggregated log-prob must equal the sum of the per-head
    # log-probs, reduced over the event dimension.
    lp = sample.get(module.log_prob_key)
    torch.testing.assert_close(
        lp,
        sample.get(key_logprob0).sum(-1) + sample.get(key_logprob1).sum(-1),
    )

@pytest.mark.parametrize(
"interaction", [InteractionType.MODE, InteractionType.MEAN]
)
Expand All @@ -2407,7 +2483,7 @@ def test_prob_module_seq(self, interaction, return_log_prob):
},
}
},
[3],
batch_size=(3,),
)
in_keys = ["params"]
out_keys = ["cont", ("nested", "cont")]
Expand Down Expand Up @@ -2446,6 +2522,60 @@ def test_prob_module_seq(self, interaction, return_log_prob):
+ sample_clone.get(("nested", "cont_log_prob")).sum(-1),
)

@pytest.mark.parametrize(
    "interaction", [InteractionType.MODE, InteractionType.MEAN]
)
def test_prob_module_seq_nested(self, interaction):
    # Same nested/batch-less scenario as test_prob_module_nested, but with
    # the probabilistic module wrapped in a ProbabilisticTensorDictSequential.
    def normal_params():
        # Fresh loc/scale leaves for one Normal head.
        return {
            "loc": torch.randn(3, 4, requires_grad=True),
            "scale": torch.rand(3, 4, requires_grad=True),
        }

    params = TensorDict(
        {
            "agents": TensorDict(
                {
                    "params": {
                        "cont": normal_params(),
                        ("nested", "cont"): normal_params(),
                    }
                },
                batch_size=3,
            ),
            "done": torch.ones(1),
        }
    )
    log_prob_key = ("agents", "sample_log_prob")
    # No-op upstream module: the sequential just needs a backbone stage.
    backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[])
    prob_module = ProbabilisticTensorDictModule(
        in_keys=[("agents", "params")],
        out_keys=["cont", ("nested", "cont")],
        distribution_class=CompositeDistribution,
        distribution_kwargs={
            "distribution_map": {
                "cont": distributions.Normal,
                ("nested", "cont"): distributions.Normal,
            }
        },
        default_interaction_type=interaction,
        return_log_prob=True,
        log_prob_key=log_prob_key,
    )
    module = ProbabilisticTensorDictSequential(backbone, prob_module)
    sample = module(params)
    assert "cont_log_prob" in sample.keys()
    assert ("nested", "cont_log_prob") in sample.keys(True)
    # Aggregated log-prob equals the event-summed per-head log-probs.
    torch.testing.assert_close(
        sample[log_prob_key],
        sample.get("cont_log_prob").sum(-1)
        + sample.get(("nested", "cont_log_prob")).sum(-1),
    )


class TestAddStateIndependentNormalScale:
def test_add_scale_basic(self, num_outputs=4):
Expand Down

0 comments on commit 4dc764c

Please sign in to comment.