Implemented CrossQ #36
Conversation
Thanks for the PR, I'll try to have a look at it soon =) (and do additional tests on pybullet envs and others). In the meantime, could you document and explain why the combined forward pass is key for CrossQ? (I need to re-read the paper)
Hey, @araffin, I added CrossQ to the README and cleaned up a little. I could not find the documentation; where should I add it? The camera-ready version is definitely worth reading :) We have slightly updated it with much longer training runs, which now train stably up to 5M env steps (as you can also see in the plot above) and some additional ablations. I would be curious to hear about your results on pybullet! Regarding the joint forward pass:
Thanks for the answer. Regarding:
In the code, just above the concatenated pass
In my tests, it seems needed for performance, not just for saving a forward pass. Do you have an idea why this change was needed (this was without target network)? What I mean is: before that commit, my CrossQ implementation was not working at all. EDIT: or is it still the same problem of action from the buffer vs. action from the policy?
I added more documentation and comments for explanation and refactored as you requested.
Yes, the problem is exactly the differently distributed batches. Let me try to explain the different scenarios.
In conclusion, you see that a joint forward pass is a really easy way to achieve exactly what we want, i.e., considering the mixture. Two consecutive forward passes open up a whole suite of edge cases and things that can go wrong. Although it is probably possible to calculate the correct mixture statistics manually and use two consecutive forward passes, it is very hard, and the implementation is error-prone. This is what makes the joint forward pass so elegant. I hope this explanation cleared up the open questions.
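To make the batch-mixture point concrete, here is a small numerical sketch (plain jax.numpy, 1-D features, made-up illustrative numbers, not code from this PR) of what a BatchNorm layer in training mode would see under the two schemes:

```python
import jax.numpy as jnp

def batch_stats(x):
    # Statistics a BatchNorm layer in training mode computes over its input batch
    return x.mean(axis=0), x.var(axis=0)

def normalize(x, mean, var, eps=1e-5):
    return (x - mean) / jnp.sqrt(var + eps)

# Toy, differently distributed batches: critic inputs built from
# replay-buffer actions vs. current-policy actions
replay_batch = jnp.array([[0.0], [2.0]])
policy_batch = jnp.array([[10.0], [12.0]])

# Joint forward pass: one set of statistics over the mixture of both batches
joint = jnp.concatenate([replay_batch, policy_batch], axis=0)
mean_j, var_j = batch_stats(joint)
norm_replay, norm_policy = jnp.split(normalize(joint, mean_j, var_j), 2, axis=0)

# Two consecutive forward passes: each batch is normalized
# with its own batch statistics
mean_r, var_r = batch_stats(replay_batch)
mean_p, var_p = batch_stats(policy_batch)
sep_replay = normalize(replay_batch, mean_r, var_r)
sep_policy = normalize(policy_batch, mean_p, var_p)
```

In this toy setup, `sep_replay` and `sep_policy` come out numerically identical (both batches collapse onto their own zero-mean, unit-variance versions), so the distribution shift between them is erased; under the joint pass, `norm_replay` and `norm_policy` stay clearly separated because both are normalized with the mixture statistics.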
# Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
# the constraints are linearly relaxed to r_max/d_max over 40k steps
# Here we only have a warmup phase
is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
Was there a reason to implement it that way? (simplicity?)
Also, how did you choose warm_up_steps: int = 100_000?
Because of the policy delay, renorm will be used only after 300_000 steps, is that intended?
Yes, honestly simplicity. We did not play around with specific schedules for relaxation or such.
I also have not done super extensive testing on the exact number of warmup steps, so there might be room for improvement, but overall it seems to be pretty robust. It did not seem to matter much at which point you end up switching, as long as it was not too late. From our initial experiments we know that vanilla BN tended to become unstable for very long runs, but that everything up to somewhere around 700k steps was fine. So we simply picked a large enough warmup phase.
The policy delay, in fact, extends the warmup phase, you are right there. I had not considered this, tbh. But I also don't think it makes a huge difference, because, as I said, we found that training was generally not super sensitive to the exact duration of the warmup interval.
Thanks for the reply =)
So far, I'm using
The only environment where it is hard to get good results is the Swimmer env.
Side note: be careful with mutable variables as default arguments (https://github.com/satwikkansal/wtfpython#-beware-of-default-mutable-arguments)
EDIT: btw, you can make things much faster by fusing gradient steps, for instance
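The linked gotcha in a nutshell (a generic Python example, not code from this PR): a mutable default value is evaluated once, at function definition time, and the same object is then shared across every call that omits the argument.

```python
def record_bad(value, history=[]):
    # BUG: the same list object is reused on every call that relies on
    # the default, so state leaks between unrelated calls
    history.append(value)
    return history

def record_good(value, history=None):
    # Safe idiom: use None as a sentinel and build a fresh list per call
    if history is None:
        history = []
    history.append(value)
    return history

print(record_bad(1))   # [1]
print(record_bad(2))   # [1, 2]  <- surprise: previous call's state leaked
print(record_good(1))  # [1]
print(record_good(2))  # [2]
```

This matters for RL hyperparameter dataclasses and function signatures: a default like `net_arch=[256, 256]` is one shared list, and mutating it in one place silently changes it everywhere.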
That is great news! :) Are you planning to push the pybullet experiments only, or also redo MuJoCo?
Scaling up CrossQ is, in fact, exactly what we are also currently looking into. We have also quickly tried combining it with TQC but have not found it to work very well at all. But we also did not have the time to further investigate this yet. But if you get interesting results/insights on any of the two I would be happy to chat about that too!
We have not looked into the Swimmer at all, but maybe we should.
Those are good points, I will keep those in mind :)
quick note to save the config that might work for swimmer:
Both, but I won't do as much as you did. Btw, I can also give you access to OpenRL benchmark (you would need to use the RL Zoo if you want to push). The report:
@danielpalen I'll merge your PR with #28 but I'll investigate a bit the influence of
Btw, would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
Yes, absolutely. I have it on my todo. But honestly, it might take some time until I have time to start working on that.
Description
I finished the CrossQ implementation based on the camera-ready version (https://openreview.net/pdf?id=PczQtTsTIX).
I am able to reproduce the results from the original paper on 10 seeds.
I have not updated the tests, as the test already existed.
Motivation and Context
Types of changes
Checklist:
make format (required)
make check-codestyle and make lint (required)
make pytest and make type both pass. (required)
make doc (required)

Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line