[BugFix] Computation of log prob in composite distribution for TensorDict without batch size #1065
Conversation
Thanks for this!
I'm happy with these changes.
However, I'm confused by the comment in the PR description; can you clarify?
To reiterate: as explained here and in the comments above, if we aggregate the log-prob we must sum over batch dimensions until the log-prob has the shape of the sample tensordict (by that I mean the root). If we don't aggregate, we keep the shape.
This is because you could have
import torch
from tensordict import TensorDict

params = TensorDict(
    {
        "agents0": TensorDict(  # params for the first agent group, batch size 3
            {
                "params": {
                    "cont": {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                    ("nested", "cont"): {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                }
            },
            batch_size=3,
        ),
        "agents1": TensorDict(  # params for the second agent group, batch size 2
            {
                "params": {
                    "cont": {
                        "loc": torch.randn(2, 4, requires_grad=True),
                        "scale": torch.rand(2, 4, requires_grad=True),
                    },
                    ("nested", "cont"): {
                        "loc": torch.randn(2, 4, requires_grad=True),
                        "scale": torch.rand(2, 4, requires_grad=True),
                    },
                }
            },
            batch_size=2,
        ),
        "done": torch.ones(1),  # global key at the root, which has no batch size
    }
)
and your log-prob will have to be reduced over all of its dims at this point (because two of the dists have a leading batch dim of 2 and the other two have a leading batch dim of 3).
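To make the shape argument concrete, here is a minimal sketch (plain torch only; the tensor shapes mirror the params above, everything else is illustrative, not the CompositeDistribution API) of why the two per-agent log-probs can only be combined into a scalar at the root:

import torch
from torch import distributions as d

# Per-agent distributions with incompatible batch shapes, as in the params above.
dist0 = d.Normal(torch.randn(3, 4), torch.rand(3, 4) + 0.1)  # "agents0": leading dim 3
dist1 = d.Normal(torch.randn(2, 4), torch.rand(2, 4) + 0.1)  # "agents1": leading dim 2

lp0 = dist0.log_prob(dist0.sample())  # shape (3, 4)
lp1 = dist1.log_prob(dist1.sample())  # shape (2, 4)

# The root tensordict has no batch size, so the only shape both log-probs
# can be reduced to before being combined is a scalar.
aggregated = lp0.sum() + lp1.sum()  # shape ()
print(aggregated.shape)  # torch.Size([])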
Thanks for your review! Following my example, I would say that in your example sampling from both agents would not have the same batch size, and we would then aggregate everything to a single scalar.
But that doesn't matter here (in the tensordict lib, "done" isn't anything special). We must also account for the example I gave, where actions can have completely different shapes.
My issue comes from the …
[BugFix] Computation of log prob in composite distribution for TensorDict without batch size (#1065) Co-authored-by: Pau Riba <[email protected]> Co-authored-by: Vincent Moens <[email protected]> (cherry picked from commit e64a4c3)
Description
I believe there's a bug in the latest version of the ProbabilisticTensorDictModule when calculating the log prob.
Motivation and Context
The log prob is erroneously flattened according to the sample's ndim, more specifically the ndim of the batch shape at the root level of the sample tensordict. However, the sample's ndim does not always match the batch shape of the distribution itself!
This is the case, for instance, if there are global keys such as done or terminated. For example:
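The code example from the original description did not survive extraction; below is a minimal sketch of the kind of setup being described, assuming a single CompositeDistribution head over a root tensordict with empty batch size and a global "done" entry. The key names and shapes are illustrative, and the exact output keys and aggregation behavior depend on the tensordict version.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, ProbabilisticTensorDictModule
from torch import distributions as d

# Root tensordict without a batch size: the distribution params have a
# leading dim of 3, but the global "done" entry keeps the root batch_size=[].
data = TensorDict(
    {
        "params": {
            "cont": {
                "loc": torch.randn(3, 4),
                "scale": torch.rand(3, 4) + 0.1,
            },
        },
        "done": torch.ones(1),
    },
    batch_size=[],
)

module = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=["cont"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={"distribution_map": {"cont": d.Normal}},
    return_log_prob=True,
)

out = module(data)
# Reported bug: the log prob is reduced down to the root's (empty) batch size,
# i.e. a single scalar, instead of keeping size 3 as expected.
print(out["sample_log_prob"].shape)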
In this case, using the ProbabilisticTensorDictModule, the sample_log_prob is flattened into a single float instead of having size 3 as expected.
This PR is related to #1054.