Added CrossQ #453

Merged
merged 6 commits into from
Oct 24, 2024
18 changes: 10 additions & 8 deletions .github/workflows/ci.yml
@@ -33,19 +33,21 @@ jobs:
run: |
python -m pip install --upgrade pip

# Use uv for faster downloads
pip install uv
# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

# cpu version of pytorch - faster to download
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu

pip install -r requirements.txt
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
pip install opencv-python-headless
pip install -e .[plots,tests]
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Lint with ruff
run: |
make lint
@@ -62,4 +64,4 @@ jobs:
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
make pytest
15 changes: 9 additions & 6 deletions .github/workflows/trained_agents.yml
@@ -34,18 +34,21 @@ jobs:
run: |
python -m pip install --upgrade pip

# Use uv for faster downloads
pip install uv
# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

# cpu version of pytorch - faster to download
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
pip install opencv-python-headless
pip install -e .[plots,tests]
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]
- name: Check trained agents
run: |
make check-trained-agents
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -1,18 +1,22 @@
## Release 2.4.0a4 (WIP)
## Release 2.4.0a10 (WIP)

**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env**

### Breaking Changes
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2)
- Upgraded to SB3 >= 2.4.0

### New Features
- Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen)

### Bug fixes
- Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz)

### Documentation

### Other
- Updated PyTorch version to 2.3.1 in the CI
- Updated PyTorch version to 2.4.1 in the CI
- Switched to uv to download packages faster on GitHub CI

## Release 2.3.0 (2024-03-31)

91 changes: 91 additions & 0 deletions hyperparams/crossq.yml
@@ -0,0 +1,91 @@
MountainCarContinuous-v0:
  n_timesteps: !!float 50000
  policy: 'MlpPolicy'
  learning_rate: !!float 7e-4
  buffer_size: 50000
  train_freq: 32
  gradient_steps: 32
  gamma: 0.9999
  learning_starts: 100
  use_sde: True
  policy_delay: 2
  policy_kwargs: "dict(use_expln=True, log_std_init=-1, net_arch=[64, 64])"

Pendulum-v1:
  n_timesteps: 20000
  policy: 'MlpPolicy'
  policy_delay: 2
  policy_kwargs: "dict(net_arch=[256, 256])"


LunarLanderContinuous-v2:
  n_timesteps: !!float 2e5
  policy: 'MlpPolicy'
  buffer_size: 1000000
  learning_starts: 10000


BipedalWalker-v3:
  n_timesteps: !!float 2e5
  policy: 'MlpPolicy'
  buffer_size: 300000
  gamma: 0.98
  learning_starts: 10000
  policy_kwargs: "dict(net_arch=dict(pi=[256, 256], qf=[1024, 1024]))"

# === Mujoco Envs ===

HalfCheetah-v4: &mujoco-defaults
  buffer_size: 1_000_000
  learning_rate: !!float 1e-3
  learning_starts: 5000
  n_timesteps: !!float 5e6
  policy: 'MlpPolicy'
  policy_delay: 3
  policy_kwargs: "dict(net_arch=dict(pi=[256, 256], qf=[2048, 2048]))"

Ant-v4:
  <<: *mujoco-defaults

Hopper-v4:
  <<: *mujoco-defaults

Walker2d-v4:
  <<: *mujoco-defaults

Humanoid-v4:
  <<: *mujoco-defaults

HumanoidStandup-v4:
  <<: *mujoco-defaults

Swimmer-v4:
  <<: *mujoco-defaults
  gamma: 0.999

# Tuned for SAC, need to check with CrossQ
HalfCheetahBulletEnv-v0: &pybullet-defaults
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  train_freq: 8
  gradient_steps: 8
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(use_expln=True, log_std_init=-3)"

# Tuned
AntBulletEnv-v0:
  <<: *pybullet-defaults

HopperBulletEnv-v0:
  <<: *pybullet-defaults
  learning_rate: lin_7.3e-4

Walker2DBulletEnv-v0:
  <<: *pybullet-defaults
  learning_rate: lin_7.3e-4
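
The `<<: *pybullet-defaults` merge keys and the `lin_7.3e-4` value in the file above can be read roughly as follows. This is a minimal Python sketch, not the zoo's implementation: it assumes the `lin_` prefix denotes a learning rate that decays linearly with the remaining training progress, it reproduces only a subset of the default keys, and the names `PYBULLET_DEFAULTS`, `merge_entry`, and `parse_learning_rate` are hypothetical.

```python
from typing import Callable, Dict, Union

# Hypothetical stand-in for the `&pybullet-defaults` anchor (subset of keys).
PYBULLET_DEFAULTS: Dict[str, object] = {
    "n_timesteps": 1e6,
    "policy": "MlpPolicy",
    "learning_rate": 7.3e-4,
    "gamma": 0.98,
}


def merge_entry(defaults: Dict[str, object], overrides: Dict[str, object]) -> Dict[str, object]:
    """Mimic a YAML merge key: explicitly set keys win over merged defaults."""
    return {**defaults, **overrides}


def parse_learning_rate(value: Union[float, str]) -> Callable[[float], float]:
    """Turn a constant or a `lin_<value>` string into a schedule.

    The returned function takes `progress_remaining`, assumed to go
    from 1.0 (start of training) down to 0.0 (end of training).
    """
    if isinstance(value, str) and value.startswith("lin_"):
        initial = float(value[len("lin_"):])
        return lambda progress_remaining: progress_remaining * initial
    constant = float(value)
    return lambda progress_remaining: constant


# HopperBulletEnv-v0 inherits the defaults and overrides the learning rate.
hopper = merge_entry(PYBULLET_DEFAULTS, {"learning_rate": "lin_7.3e-4"})
hopper_schedule = parse_learning_rate(hopper["learning_rate"])

# AntBulletEnv-v0 inherits everything unchanged, so its rate stays constant.
ant = merge_entry(PYBULLET_DEFAULTS, {})
ant_schedule = parse_learning_rate(ant["learning_rate"])
```

Under these assumptions, `hopper_schedule` starts at the same value as the constant default and reaches zero at the end of training, while every other key (for example `gamma`) is shared with the anchor entry.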
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a0,<3.0
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0
box2d-py==2.3.8
pybullet_envs_gymnasium>=0.4.0
# minigrid
3 changes: 2 additions & 1 deletion rl_zoo3/push_to_hub.py
@@ -10,7 +10,7 @@

import torch as th
import yaml
from huggingface_hub import HfApi, Repository
from huggingface_hub import HfApi
from huggingface_hub.repocard import metadata_save
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId
from huggingface_sb3.push_to_hub import _evaluate_agent, _generate_replay, generate_metadata
@@ -83,6 +83,7 @@ def generate_model_card(
RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo<br/>
SB3: https://github.com/DLR-RM/stable-baselines3<br/>
SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
SBX (SB3 + Jax): https://github.com/araffin/sbx

Install the RL Zoo (with SB3 and SB3-Contrib):
```bash
3 changes: 2 additions & 1 deletion rl_zoo3/utils.py
@@ -12,7 +12,7 @@
from gymnasium import spaces
from huggingface_hub import HfApi
from huggingface_sb3 import EnvironmentName, ModelName
from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from sb3_contrib import ARS, QRDQN, TQC, TRPO, CrossQ, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
@@ -32,6 +32,7 @@
"td3": TD3,
# SB3 Contrib,
"ars": ARS,
"crossq": CrossQ,
"qrdqn": QRDQN,
"tqc": TQC,
"trpo": TRPO,
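
Adding the `"crossq": CrossQ` entry is what lets the class be looked up by its string key. The sketch below illustrates that pattern with placeholder classes so it runs without `sb3_contrib` installed; the dictionary name `ALGOS`, the helper `get_algo_class`, and the stand-in classes are assumptions for illustration, not the zoo's actual code.

```python
from typing import Dict, Type


class BaseAlgorithm:
    """Hypothetical stand-in for the real algorithm base class."""


class TQC(BaseAlgorithm):
    """Placeholder for sb3_contrib.TQC."""


class CrossQ(BaseAlgorithm):
    """Placeholder for sb3_contrib.CrossQ."""


# Name-to-class registry, mirroring the shape of the mapping in the diff.
ALGOS: Dict[str, Type[BaseAlgorithm]] = {
    "crossq": CrossQ,
    "tqc": TQC,
}


def get_algo_class(name: str) -> Type[BaseAlgorithm]:
    """Resolve an algorithm name, failing loudly on unknown keys."""
    try:
        return ALGOS[name]
    except KeyError:
        known = ", ".join(sorted(ALGOS))
        raise ValueError(f"Unknown algorithm {name!r}; expected one of: {known}") from None


model_class = get_algo_class("crossq")
```

Without the new dictionary entry, the same lookup for `"crossq"` would raise instead of returning a class.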
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
2.4.0a4
2.4.0a10
5 changes: 2 additions & 3 deletions setup.py
@@ -15,7 +15,7 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.4.0a4,<3.0",
"sb3_contrib>=2.4.0a10,<3.0",
"gymnasium~=0.29.1",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
@@ -24,8 +24,7 @@
"pyyaml>=5.1",
"pytablewriter~=1.2",
]
# TODO(antonin): update to rliable>=1.1.0 once PR is merged and released
plots_requires = ["seaborn", "rliable @ git+https://github.com/araffin/rliable@patch-1", "scipy~=1.10"]
plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"]
test_requires = [
# for MuJoCo envs v4:
"mujoco~=2.3",
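
Several requirements in this file use the compatible-release operator `~=` (for example `rliable~=1.2.0` and `scipy~=1.10`). Under PEP 440, `~=1.2.0` is equivalent to `>=1.2.0, ==1.2.*`, and `~=1.10` to `>=1.10, ==1.*`. The sketch below illustrates that rule for plain numeric versions only; it ignores pre-releases, epochs, and local versions, and the helper names are hypothetical. A real project would normally rely on the `packaging` library instead.

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse a purely numeric dotted version such as '1.2.0'."""
    return tuple(int(part) for part in text.split("."))


def compatible_release(candidate: str, spec: str) -> bool:
    """Return True if `candidate` satisfies `~=spec` (numeric versions only)."""
    base = parse_version(spec)
    if len(base) < 2:
        raise ValueError("~= requires at least two release segments")
    cand = parse_version(candidate)

    # Pad with zeros so that '1.2' and '1.2.0' compare as equal.
    width = max(len(cand), len(base))
    cand_padded = cand + (0,) * (width - len(cand))
    base_padded = base + (0,) * (width - len(base))

    # All segments except the last one of the spec must match exactly.
    prefix = base[:-1]
    return cand_padded >= base_padded and cand_padded[: len(prefix)] == prefix


rliable_ok = compatible_release("1.2.5", "1.2.0")       # same 1.2 series
rliable_too_new = compatible_release("1.3.0", "1.2.0")  # outside 1.2 series
scipy_ok = compatible_release("1.14.1", "1.10")         # any later 1.x
```

The number of segments in the specifier matters: with three segments only the patch level may float, with two segments the minor version may float as well.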