-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CrossQ #28
Add CrossQ #28
Conversation
…C by replacing the unrolled loop with jax.lax.fori_loop
Implemented CrossQ
@danielpalen after reading the paper, I'm wondering if you have the learning curves for |
sbx/crossq/policies.py
Outdated
if optimizer_kwargs is None: | ||
# Note: the default value for b1 is 0.9 in Adam. | ||
# b1=0.5 is used in the original CrossQ implementation and is found | ||
# but shows only little overall improvement. | ||
optimizer_kwargs = {} | ||
if optimizer_class in [optax.adam, optax.adamw]: | ||
optimizer_kwargs["b1"] = 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, the default value of b1
is set to 0.5
only if no other arguments (or even an empty dict) are passed to the optimizer. It would be cleaner to set the default value to 0.5
regardless of the other optimizer parameters.
if optimizer_kwargs is None: | |
# Note: the default value for b1 is 0.9 in Adam. | |
# b1=0.5 is used in the original CrossQ implementation and is found | |
# but shows only little overall improvement. | |
optimizer_kwargs = {} | |
if optimizer_class in [optax.adam, optax.adamw]: | |
optimizer_kwargs["b1"] = 0.5 | |
if optimizer_kwargs is None: | |
optimizer_kwargs = {} | |
if optimizer_class in [optax.adam, optax.adamw] and "b1" not in optimizer_kwargs: | |
# Note: the default value for b1 is 0.9 in Adam. | |
# b1=0.5 is used in the original CrossQ implementation but shows only little overall improvement. | |
optimizer_kwargs["b1"] = 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would keep it as is to be consistent with what is done in the rest of SB3.
PRNGKey = Any | ||
Array = Any | ||
Shape = Tuple[int, ...] | ||
Dtype = Any # this could be a real type? | ||
Axes = Union[int, Sequence[int]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flax v0.8.1 introduced flax.typing, which we could use here for more descriptive type hints, similar to the current version of flax.linen.normalization. However, we should probably wait a bit here since this would require a relatively recent flax version.
Co-authored-by: Jan Schneider <[email protected]>
@danielpalen Some early results of DroQ + CrossQ (only 2 random seeds on 3 pybullet envs, need more runs): https://wandb.ai/openrlbenchmark/sbx/reports/DroQ-CrossQ-SBX-Perf-Report--Vmlldzo3MzcxNDUy I also quickly checked the warmup steps and could see an impact on |
I quickly checked and it looked pretty similar. |
I have also played around with REDQ/DroQ + CrossQ on MuJoCo but from what I remember, the results were not really consistent, sometimes better, sometimes worse.
That makes sense. If you go to low you don't have a good estimate for the running statistics yet, so you need to give them enough time to warm up. But the exact time will be environment specific I guess |
So far, it always improved the results in my case (need more seeds to confirm, I have tried on different pybullet and mujoco envs), or at least to quickly get "good enough" solution (using up to 2x less samples than CrossQ). One last point in case you missed it (because from #36 (comment)): |
Yes, absolutely :) I put it on my todo. But I think I won't be able to get on that right away at the moment. |
Description
Implementing https://openreview.net/forum?id=PczQtTsTIX
on top of #21
Discussion in #36
perf report:
https://wandb.ai/openrlbenchmark/sbx/reports/CrossQ-SBX-Perf-Report--Vmlldzo3MzQxOTAw
Motivation and Context
Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line