
Implemented CrossQ #36

Merged · 7 commits merged into araffin:feat/crossq on Mar 29, 2024

Conversation

@danielpalen (Contributor)

Description

I finished the CrossQ implementation based on the camera-ready version (https://openreview.net/pdf?id=PczQtTsTIX).

I am able to reproduce the results from the original paper on 10 seeds.
[Figure: learning curves reproducing the original paper's results over 10 seeds (sbx_reproduce)]

I have not updated the tests, as a test for this already existed.

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@araffin araffin self-requested a review March 26, 2024 14:04
@araffin (Owner) commented Mar 26, 2024

Thanks for the PR, I'll try to have a look at it soon =) (and do additional tests on pybullet envs and others).

In the meantime, could you document and explain why the combined forward pass is key for CrossQ? (I need to re-read the paper.)

@danielpalen (Contributor, Author)

Hey, @araffin, I added CrossQ to the README and cleaned up a little.

I could not find the documentation; where should I add it?

The camera-ready version is definitely worth reading :) We have slightly updated it with much longer training runs, which now train stably up to 5M env steps (as you can also see in the plot above) and some additional ablations. I would be curious to hear about your results on pybullet!

Regarding the joint forward pass:
In the paper, we recognize that "naively" adding BN on top of SAC likely fails because of the different batch statistics of $(s,a)$ from the replay buffer and $(s',a'\sim\pi(s'))$, where the actions are on-policy. In SAC (with target networks), the weights and batch statistics are Polyak-averaged from the live network to the target network, which means that the $(s',a')$ batches will likely be perceived as out-of-distribution by the BN layers in the target network. BN itself is known to have stability issues with out-of-distribution samples.
If you want to keep the target network, there might be other, more complicated ways to implement this, such as calculating the correct joint statistics, etc. However, the easiest way is probably to remove the target network completely. This is how we end up with CrossQ, which trains stably even without target networks. The joint forward pass itself is a nice implementation detail, which saves one forward pass and directly calculates the correct statistics of the mixture distribution.
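
For illustration, here is a minimal sketch of such a joint forward pass in JAX/flax (the function and variable names and the exact module interface are assumptions for illustration, not the literal code of this PR):

import jax.numpy as jnp

def joint_q_forward(qf_state, obs, actions, next_obs, next_actions):
    # Concatenate both batches so the BN layers compute the statistics of the
    # *mixture* of (s, a) from the buffer and (s', a' ~ pi(s')) in one pass.
    joint_obs = jnp.concatenate([obs, next_obs], axis=0)
    joint_actions = jnp.concatenate([actions, next_actions], axis=0)

    joint_q, state_updates = qf_state.apply_fn(
        {"params": qf_state.params, "batch_stats": qf_state.batch_stats},
        joint_obs,
        joint_actions,
        train=True,               # update running statistics with mixture stats
        mutable=["batch_stats"],  # flax returns the updated BN state separately
    )

    # Split the stacked output back into current and next Q-values.
    q_values, next_q_values = jnp.split(joint_q, 2, axis=0)
    return q_values, next_q_values, state_updates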

@araffin (Owner) commented Mar 27, 2024

Thanks for the answer.

Regarding:

I could not find the documentation; where should I add it?

In the code, just above the concatenated pass

The joint forward pass itself is a nice implementation detail, which saves one forward pass and directly calculates the correct statistics of the mixture distribution.

In my tests, it seems needed for performance, not just for saving a forward pass. Do you have an idea why this change was needed (this was without a target network):
03ba862

What I mean is that before that commit, my CrossQ implementation was not working at all.
The only thing I can think of so far is that train=True in the second case slightly changes the stats, but I'm surprised that it doesn't work at all.

EDIT: or is it still the same problem of action from the buffer vs action from the policy?

Resolved (outdated) review threads on sbx/crossq/batch_renorm.py and sbx/crossq/crossq.py.
@danielpalen (Contributor, Author)

I added more documentation and comments for explanation and refactored as you requested.

In my tests, it seems needed for performance, not just for saving a forward pass. Do you have an idea why this change was needed (this was without a target network)? 03ba862

What I mean is that before that commit, my CrossQ implementation was not working at all. The only thing I can think of so far is that train=True in the second case slightly changes the stats, but I'm surprised that it doesn't work at all.

EDIT: or is it still the same problem of action from the buffer vs action from the policy?

Yes, the problem is exactly the differently distributed batches. Let me try to explain the different scenarios.

  1. Joint forward pass with train=True (CrossQ): This will nicely calculate the batch statistics of the mixture distribution as needed.

  2. Consecutive forward passes:
    a) With train=False for next_state/next_action: This will not work because the running statistics are only calculated using the state/action batch.
    b) With train=True for both passes: This is also not ideal. On the one hand, the running statistics would be roughly correct, just updated one batch at a time instead of with the mixture batch, but that should be okay. On the other hand, the individual batches will be normalized with their individual statistics instead of the mixture statistics, as they should be. This normalization would also not match the BN layer in eval mode (i.e., when using the running statistics), since it expects batches from the mixture.
    And there is one more case distinction: the last scenario holds when using vanilla BatchNorm, or during the warmup phase of BatchRenorm (since there it uses the batch statistics before later switching to only using running statistics). After the warmup phase, this configuration should work for BatchRenorm (provided, of course, that the initially incorrect behavior during warmup did not already break training).

In conclusion, you see that a joint forward pass is a really easy way to achieve exactly what we want, i.e., considering the mixture. Two consecutive forward passes open up a whole suite of edge cases and things that can go wrong. Although it is probably possible to calculate the correct mixture statistics manually and use two consecutive forward passes, it is very hard, and the implementation is error-prone. This is what makes the joint forward pass so elegant.
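
To make the statistics argument concrete, here is a toy example (hypothetical numbers, plain jnp) showing that normalizing two batches separately is not the same as normalizing the concatenated mixture batch:

import jax.numpy as jnp

batch_a = jnp.array([0.0, 1.0, 2.0])  # e.g. features of (s, a) from the buffer
batch_b = jnp.array([5.0, 6.0, 7.0])  # e.g. features of on-policy (s', a')

# Per-batch statistics, as in two consecutive train=True passes:
print(batch_a.mean(), batch_a.std())  # 1.0, ~0.82
print(batch_b.mean(), batch_b.std())  # 6.0, ~0.82

# Mixture statistics, as computed by the joint forward pass:
mixture = jnp.concatenate([batch_a, batch_b])
print(mixture.mean(), mixture.std())  # 3.5, ~2.63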

I hope this explanation cleared up the open questions.

Comment on lines +183 to +186
# Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
# the constraints are linearly relaxed to r_max/d_max over 40k steps
# Here we only have a warmup phase
is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
@araffin (Owner)

Was there a reason to implement it that way? (simplicity?)

Also, how did you choose warm_up_steps: int = 100_000?
Because of the policy delay, renorm will only be used after 300_000 steps; is that intended?

@danielpalen (Contributor, Author)

Yes, honestly, simplicity. We did not play around with specific schedules for relaxation or such.
I also have not done super extensive testing on the exact number of warm-up steps; there might be room for improvement. But overall it seems to be pretty robust, and it did not seem to matter much at which point you end up switching, as long as it was not too late. From our initial experiments, we know that vanilla BN tended to become unstable for very long runs, but that everything up to somewhere around 700k steps was fine. So we simply picked a large enough warm-up phase.

The policy delay, in fact, extends the warm-up phase; you are right there. I did not consider this, tbh. But I also don't think it makes a huge difference, because, as I said, we found that training was generally not very sensitive to the exact duration of the warm-up interval.
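
For reference, a schematic of how such a hard warm-up switch can be gated in BatchRenorm (a sketch using the standard r/d correction terms from the BatchRenorm paper; names and defaults are assumptions, not the exact code of this PR):

import jax
import jax.numpy as jnp

def renorm_correction(batch_mean, batch_var, ra_mean, ra_var,
                      steps, warm_up_steps, r_max=3.0, d_max=5.0, eps=1e-5):
    # Standard BatchRenorm correction terms (Ioffe, 2017):
    # x_hat = (x - batch_mean) / batch_std * r + d
    batch_std = jnp.sqrt(batch_var + eps)
    ra_std = jnp.sqrt(ra_var + eps)
    # r and d are treated as constants in the backward pass.
    r = jax.lax.stop_gradient(jnp.clip(batch_std / ra_std, 1.0 / r_max, r_max))
    d = jax.lax.stop_gradient(
        jnp.clip((batch_mean - ra_mean) / ra_std, -d_max, d_max)
    )

    # During warm-up behave like vanilla BatchNorm (r=1, d=0);
    # afterwards switch the correction on, with no linear relaxation.
    is_warmed_up = jnp.greater_equal(steps, warm_up_steps).astype(jnp.float32)
    r = is_warmed_up * r + (1.0 - is_warmed_up)
    d = is_warmed_up * d
    return r, d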

@araffin (Owner) commented Mar 28, 2024

Thanks for the reply =)
I've cleaned things up a bit (notably batch renorm) and I've got good initial results on pybullet (I will try to publish some longer runs on the openrl benchmark W&B board).

So far, I'm using qf=[1024, 1024] instead of [2048, 2048] because it speeds things up without compromising performance.
I also checked that CrossQ plays well with the DroQ configuration (adding dropout and UTD=10, policy_delay=10).
I'm curious to combine it with TQC (will do that in a separate branch).
My command lines: https://gist.github.com/araffin/8ba37f07065a57caed786a855bfc1ba5 and https://gist.github.com/araffin/0fb74ab1b23da41df36057808705bc7d

The only environment where it is hard to get good results is the Swimmer env.
It can normally be solved in 30 lines of code (see https://arxiv.org/abs/2310.05808 and code: https://gist.github.com/araffin/25159d668e9bad41bf31a595add22c27) and it is known to be hard for RL (https://arxiv.org/abs/2208.07587), but it is difficult to make CrossQ work on it (even with the recommended gamma=0.9999).
I think the only time I could make it work was with the DroQ configuration (in comparison, DroQ can solve it in 30k steps).

Side note: be careful with mutable variables as default arguments (https://github.com/satwikkansal/wtfpython#-beware-of-default-mutable-arguments)

EDIT: btw, you can make things much faster by fusing gradient steps, for instance train_freq:4 gradient_steps:4 or train_freq:1 gradient_steps:4 n_envs:4
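
As a usage illustration (assuming CrossQ from this PR exposes the usual SB3-style constructor arguments; treat the exact parameters and the env as assumptions), the fused setting would look like:

from sbx import CrossQ

# 4 gradient steps fused per 4 env steps instead of 1 update per step
model = CrossQ(
    "MlpPolicy",
    "HalfCheetah-v4",
    train_freq=4,
    gradient_steps=4,
    policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[1024, 1024])),
    verbose=1,
)
model.learn(total_timesteps=100_000)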

@danielpalen (Contributor, Author)

Thanks for the reply =) I've cleaned things up a bit (notably batch renorm) and I've got good initial results on pybullet (I will try to publish some longer runs on the openrl benchmark W&B board).

That is great news! :) Are you planning to push only the pybullet experiments or also redo MuJoCo?

So far, I'm using qf=[1024, 1024] instead of [2048, 2048] because it speeds things up without compromising performance. I also checked that CrossQ plays well with the DroQ configuration (adding dropout and UTD=10, policy_delay=10). I'm curious to combine it with TQC (will do that in a separate branch). My command lines: https://gist.github.com/araffin/8ba37f07065a57caed786a855bfc1ba5 and https://gist.github.com/araffin/0fb74ab1b23da41df36057808705bc7d

Scaling up CrossQ is, in fact, exactly what we are currently looking into as well. We have also quickly tried combining it with TQC but did not find it to work very well at all; we have not had the time to investigate this further yet. If you get interesting results/insights on either of the two, I would be happy to chat about that too!

The only environment where it is hard to get good results is the Swimmer env. It can normally be solved in 30 lines of code (see https://arxiv.org/abs/2310.05808 and code: https://gist.github.com/araffin/25159d668e9bad41bf31a595add22c27) and it is known to be hard for RL (https://arxiv.org/abs/2208.07587), but it is difficult to make CrossQ work on it (even with the recommended gamma=0.9999). I think the only time I could make it work was with the DroQ configuration (in comparison, DroQ can solve it in 30k steps).

We have not looked into the Swimmer at all, but maybe we should.

Side note: be careful of mutable variables as default arguments (https://github.com/satwikkansal/wtfpython#-beware-of-default-mutable-arguments)
EDIT: btw, you can make things much faster by fusing gradient steps, for instance, train_freq:4 gradient_steps:4 or train_freq:1 gradient_steps:4 n_envs:4

Those are good points, I will keep those in mind :)

@araffin (Owner) commented Mar 28, 2024

Quick note to save the config that might work for Swimmer:

python train.py --algo crossq --env Swimmer-v4 --eval-episodes 20 --eval-freq 10000 \
--n-eval-envs 5 --verbose 0 -P -c hyperparams/sac.py \
-params gradient_steps:20 n_envs:2 policy_delay:10 \
learning_rate:0.0003 qf_learning_rate:0.001 \
policy_kwargs:"dict(net_arch=dict(pi=[256,256],qf=[256,256]), dropout_rate=0.01, batch_norm_actor=False)" \
gamma:0.9999

@araffin (Owner) commented Mar 29, 2024

Are you planning to push the pybullet experiments only or also redo MuJoCo?

Both, though I won't do as many runs as you did. Btw, I can also give you access to OpenRL Benchmark (you would need to use the RL Zoo if you want to push).
For now, I'm doing only 3 random seeds and 500k steps to quickly get an overview (each experiment takes ~10 minutes).

The report:
https://wandb.ai/openrlbenchmark/sbx/reports/CrossQ-SBX-Perf-Report--Vmlldzo3MzQxOTAw

@araffin (Owner) commented Mar 29, 2024

@danielpalen I'll merge your PR with #28, but I'll investigate the influence of warm_up_steps a bit before merging.

Btw, would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

EDIT: the only env where CrossQ underperforms is Hopper-v4; any intuition as to why?
I just realized I ran the experiments with gamma=0.98, as for pybullet.

@araffin araffin merged commit e9262a1 into araffin:feat/crossq Mar 29, 2024
@araffin araffin mentioned this pull request Mar 30, 2024
@danielpalen (Contributor, Author)

would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

Yes, absolutely, I have it on my todo list. Honestly, though, it might take some time before I can start working on it.
