[BugFix] Computation of log prob in composite distribution for TensorDict without batch size #1065

Merged: 2 commits merged into pytorch:main on Nov 1, 2024

Conversation

@priba (Contributor) commented on Oct 30, 2024

Description

I believe there's a bug in the latest version of the ProbabilisticTensorDictModule when calculating the log prob.

Motivation and Context

The log prob is erroneously flattened according to the sample's ndim, more specifically the ndim of the batch shape at the root level of the sample tensordict. However, the sample's ndim does not always match the batch shape of the distribution itself!

This is the case, for instance, if there are global keys such as done or terminated.

For example:

import torch
from tensordict import TensorDict

params = TensorDict(
    {
        "agents": TensorDict(
            {
                "params": {
                    "cont": {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                    ("nested", "cont"): {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                }
            },
            batch_size=3,
        ),
        "done": torch.ones(1),
    }
)

In this case, using the ProbabilisticTensorDictModule, the sample_log_prob is flattened into a single float instead of having size 3 as expected.
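For context, a minimal sketch of how such a module might be wired up to reproduce this, reusing the params above. Only ProbabilisticTensorDictModule and CompositeDistribution are taken from the report; the in_keys/out_keys, distribution_map and name_map choices below are assumptions for illustration:

from torch import distributions as d
from tensordict.nn import CompositeDistribution, ProbabilisticTensorDictModule

# Hypothetical wiring: one Normal per parameter leaf, reusing the `params`
# TensorDict defined above. The exact key plumbing here is an assumption.
module = ProbabilisticTensorDictModule(
    in_keys=[("agents", "params")],
    out_keys=[("agents", "cont"), ("agents", "nested", "cont")],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {
            "cont": d.Normal,
            ("nested", "cont"): d.Normal,
        },
        "name_map": {
            "cont": ("agents", "cont"),
            ("nested", "cont"): ("agents", "nested", "cont"),
        },
    },
    return_log_prob=True,
)

out = module(params)
# Expected: sample_log_prob of shape (3,), matching the batch size of the
# "agents" sub-tensordict. Before this fix it was reduced to a scalar because
# the root tensordict has an empty batch size.
print(out["sample_log_prob"].shape)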

This PR is related to #1054.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label on Oct 30, 2024
@vmoens added the bug label on Oct 31, 2024
@vmoens (Contributor) left a comment

Thanks for this!
I'm happy with these changes.

However, I'm confused by the comment in the PR description; can you clarify?

To reiterate: as explained here and in the comments above, if we aggregate the log-prob we must sum over all batch dimensions until the log-prob has the shape of the sample tensordict (by that I mean the root). If we don't aggregate, we keep the shape.
This is because you could have

params = TensorDict(
    {
        "agents0": TensorDict(
            {
                "params": {
                    "cont": {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                    ("nested", "cont"): {
                        "loc": torch.randn(3, 4, requires_grad=True),
                        "scale": torch.rand(3, 4, requires_grad=True),
                    },
                }
            },
            batch_size=3,
        ),
        "agents1": TensorDict(
            {
                "params": {
                    "cont": {
                        "loc": torch.randn(2, 4, requires_grad=True),
                        "scale": torch.rand(2, 4, requires_grad=True),
                    },
                    ("nested", "cont"): {
                        "loc": torch.randn(2, 4, requires_grad=True),
                        "scale": torch.rand(2, 4, requires_grad=True),
                    },
                }
            },
            batch_size=2,
        ),
        "done": torch.ones(1),
    }
)

and your log-prob will have to be reduced over all of its dims at this point (because two of the distributions have batch shape 2 and the other two have batch shape 3).
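To make that concrete, a small shape check with plain torch.distributions (not the library code):

import torch
from torch import distributions as d

# Two "agents" whose distributions have different batch shapes, as above.
dist0 = d.Normal(torch.randn(3, 4), torch.rand(3, 4))  # batch shape (3, 4)
dist1 = d.Normal(torch.randn(2, 4), torch.rand(2, 4))  # batch shape (2, 4)

lp0 = dist0.log_prob(dist0.sample())  # shape (3, 4)
lp1 = dist1.log_prob(dist1.sample())  # shape (2, 4)

# (3, ...) and (2, ...) share no common non-empty batch shape, so the only
# consistent aggregate across both agents is a full reduction to a scalar:
total_log_prob = lp0.sum() + lp1.sum()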

Two review threads on test/test_nn.py (outdated, resolved).
@priba (Contributor, Author) commented on Oct 31, 2024

Thanks for your review!

Following my example I would say that the aggregated log_prob should keep the batch size of the sampled actions. In my example, "done" is not sampled from the ProbabilisticTensorDictModule.

In your example, samples from the two agents do not have the same batch size, so everything would be aggregated into a single scalar.

@vmoens (Contributor) commented on Oct 31, 2024

Following my example I would say that the aggregated log_prob should keep the batch size of the sampled actions. In my example, "done" is not sampled from the ProbabilisticTensorDictModule.

But that doesn't matter here (in the tensordict lib, "done" isn't anything special). We must account for the example I gave too, where actions can have completely different shapes.
If you need log-probs that are not aggregated, the CompositeDistribution gives you just that.
You could append a tensordictmodule that aggregates things according to any heuristic if you need to.
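A sketch of that last suggestion, following the key layout of the first example (the per-leaf log-prob key names and the summing heuristic are assumptions for illustration):

from tensordict.nn import TensorDictModule

# Hypothetical aggregation step appended after the probabilistic module:
# reduce each per-leaf log-prob (assumed shape (3, 4) here) over its last
# dim and sum the leaves into a single "sample_log_prob" entry.
aggregate = TensorDictModule(
    lambda lp_cont, lp_nested: lp_cont.sum(-1) + lp_nested.sum(-1),
    in_keys=[("agents", "cont_log_prob"), ("agents", "nested", "cont_log_prob")],
    out_keys=["sample_log_prob"],
)

# e.g. sample_log_prob = aggregate(module(params))["sample_log_prob"]  # shape (3,)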

@vmoens merged commit e64a4c3 into pytorch:main on Nov 1, 2024
48 of 57 checks passed
@priba (Contributor, Author) commented on Nov 4, 2024

My issue comes from the ProbabilisticActor (because it inherits from ProbabilisticTensorDictModule). I think #1054 (comment) is a much clearer example of what was going on and what this PR solved.

vmoens pushed a commit that referenced this pull request on Nov 4, 2024:

[BugFix] Computation of log prob in composite distribution for TensorDict without batch size (#1065)
Co-authored-by: Pau Riba <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
(cherry picked from commit e64a4c3)