Merge pull request #36 from danielpalen/feat/crossq
Implemented CrossQ
araffin authored Mar 29, 2024
2 parents cd02332 + ddc6c90 commit e9262a1
Showing 7 changed files with 331 additions and 73 deletions.
8 changes: 4 additions & 4 deletions Makefile
@@ -12,19 +12,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero

format:
# Sort imports
ruff --select I ${LINT_PATHS} --fix
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
ruff --select I ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

9 changes: 6 additions & 3 deletions README.md
@@ -17,6 +17,7 @@ Implemented algorithms:
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)


### Install using pip
@@ -36,7 +37,7 @@ pip install sbx-rl
```python
import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")
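# The rest of this example is collapsed in the diff view below. As a minimal
# sketch (assuming the SB3-style API that SBX shares, see the note further down),
# training with the newly added CrossQ would look roughly like:
model = CrossQ("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)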

@@ -61,7 +62,7 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -70,6 +71,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
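# A minimal sketch (an assumption, not part of the diff) of how the registered
# algorithm could then be launched through the RL Zoo entry point:
if __name__ == "__main__":
    train()  # e.g. run as `python train.py --algo crossq --env Pendulum-v1`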

@@ -89,7 +91,7 @@ The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -98,6 +100,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

208 changes: 208 additions & 0 deletions sbx/common/jax_layers.py
@@ -0,0 +1,208 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
from jax.nn import initializers

PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]


class BatchRenorm(Module):
"""BatchRenorm Module (https://arxiv.org/abs/1702.03275).
Adapted from flax.linen.normalization.BatchNorm
BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
This makes it less prone to suffer from "outlier" batches that can happen
during very long training runs and is, therefore, more robust over long runs.
During the warmup phase, it behaves exactly like a BatchNorm layer.
Usage Note:
If we define a model with BatchRenorm, for example::
BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32)
The initialized variables dict will contain in addition to a 'params'
collection a separate 'batch_stats' collection that will contain all the
running statistics for all the BatchRenorm layers in a model::
vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...}
We then update the batch_stats during training by specifying that the
`batch_stats` collection is mutable in the `apply` method for our module.::
vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']
During eval we would define BRN with `use_running_average=True` and use the
batch_stats collection from training to set the statistics. In this case
we are not mutating the batch statistics collection, and needn't mark it
mutable::
vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BRN.apply(vars_in, x)
Attributes:
use_running_average: if True, the statistics stored in batch_stats will be
used. Else the running statistics will be first updated and then used to normalize.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of the batch
statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""

use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.99
epsilon: float = 0.001
warm_up_steps: int = 100_000
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
# This parameter was added in flax.linen 0.7.2 (08/2023)
# commented out to be compatible with a wider range of jax versions
# TODO: re-activate in some months (04/2024)
# use_fast_variance: bool = True

@compact
def __call__(self, x, use_running_average: Optional[bool] = None):
"""Normalizes the input using batch statistics.
NOTE:
During initialization (when `self.is_initializing()` is `True`) the running
average of the batch statistics will not be updated. Therefore, the inputs
fed during initialization don't need to match the actual input
distribution and the reduction axis (set with `axis_name`) does not have
to exist.
Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.
Returns:
Normalized inputs (the same shape as inputs).
"""

use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
feature_shape = [x.shape[ax] for ax in feature_axes]

ra_mean = self.variable(
"batch_stats",
"mean",
lambda s: jnp.zeros(s, jnp.float32),
feature_shape,
)
ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)

r_max = self.variable(
"batch_stats",
"r_max",
lambda s: s,
3,
)
d_max = self.variable(
"batch_stats",
"d_max",
lambda s: s,
5,
)
steps = self.variable(
"batch_stats",
"steps",
lambda s: s,
0,
)

if use_running_average:
mean, var = ra_mean.value, ra_var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
# use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
if not self.is_initializing():
r = jnp.array(1.0)
d = jnp.array(0.0)
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var.value + self.epsilon)
# scale
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max.value, r_max.value)
# bias
d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
d = jnp.clip(d, -d_max.value, d_max.value)

# BatchNorm normalization, using minibatch stats and running average stats
# Because we use _normalize, this is equivalent to
# ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
# where sigma = sqrt(var)
affine_mean = mean - d * jnp.sqrt(var) / r
affine_var = var / (r**2)

# Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
# the constraints are linearly relaxed to r_max/d_max over 40k steps
# Here we only have a warmup phase
is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * custom_var
custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * custom_mean

ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * mean
ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * var
steps.value += 1

return _normalize(
self,
x,
custom_mean,
custom_var,
reduction_axes,
feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
self.use_bias,
self.use_scale,
self.bias_init,
self.scale_init,
)
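
The docstring above already outlines the intended usage; the following is a minimal, self-contained sketch of how this `BatchRenorm` layer would be exercised with Flax, assuming the module is importable as `sbx.common.jax_layers` (the path in the diff header). During training the `batch_stats` collection has to be marked mutable so the running mean/variance and the step counter are updated; at evaluation time the stored statistics are used unchanged.

```python
import jax
import jax.numpy as jnp

from sbx.common.jax_layers import BatchRenorm  # path taken from the diff header above

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 8))  # (batch, features)

brn = BatchRenorm(momentum=0.99, epsilon=0.001)

# Initialization: creates both 'params' (scale/bias) and 'batch_stats'
# (running mean/var, r_max, d_max, step counter).
variables = brn.init(key, x, use_running_average=False)

# Training step: 'batch_stats' must be mutable so the running statistics are updated.
y, mutated = brn.apply(variables, x, use_running_average=False, mutable=["batch_stats"])
new_batch_stats = mutated["batch_stats"]

# Evaluation: normalize with the stored running statistics, nothing is mutated.
y_eval = brn.apply(
    {"params": variables["params"], "batch_stats": new_batch_stats},
    x,
    use_running_average=True,
)
```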
43 changes: 18 additions & 25 deletions sbx/crossq/crossq.py
@@ -53,16 +53,15 @@ def __init__(
self,
policy,
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
learning_rate: Union[float, Schedule] = 1e-3,
qf_learning_rate: Optional[float] = None,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
policy_delay: int = 1,
policy_delay: int = 3,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -85,7 +84,6 @@ def __init__(
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
@@ -106,9 +104,6 @@ def __init__(
self.policy_delay = policy_delay
self.ent_coef_init = ent_coef

# if "optimizer_kwargs" not in self.policy_kwargs:
# self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}

if _init_setup_model:
self._setup_model()

@@ -256,20 +251,25 @@ def update_critic(

ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})

# qf_next_values = qf_state.apply_fn(
# {"params": qf_state.params, "batch_stats": qf_state.batch_stats},
# next_observations,
# next_state_actions,
# rngs={"dropout": dropout_key_target},
# train=False, # todo: concatenate with obs, use train=True in that case
# )

# TODO: concatenate obs/next obs
def mse_loss(
params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
) -> Tuple[jax.Array, jax.Array]:
# Concatenate obs/next_obs to have only one forward pass
# shape is (n_critics, 2 * batch_size, 1)

# Joint forward pass of obs/next_obs and actions/next_state_actions to have only
# one forward pass with shape (n_critics, 2 * batch_size, 1).
#
# This has two reasons:
# 1. According to the paper, obs/actions and next_obs/next_state_actions are differently
# distributed, which is why "naively" applying Batch Normalization in SAC fails.
# The batch statistics instead have to be calculated for the mixture distribution of obs/next_obs
# and actions/next_state_actions. Otherwise, next_obs/next_state_actions are perceived as
# out-of-distribution by the Batch Normalization layer, since the running statistics are only
# Polyak-averaged from the live network and have never seen the next batch, which is known to be unstable.
# Without target networks, the joint forward pass is a simple solution to calculate
# the joint batch statistics directly with a single forward pass.
#
# 2. From a computational perspective, a single forward pass is simply more efficient than
# two sequential forward passes.
q_values, state_updates = qf_state.apply_fn(
{"params": params, "batch_stats": batch_stats},
jnp.concatenate([observations, next_observations], axis=0),
@@ -347,12 +347,6 @@ def actor_loss(

return actor_state, qf_state, actor_loss_value, key, entropy

# @staticmethod
# @jax.jit
# def soft_update(tau: float, qf_state: BatchNormTrainState):
# qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
# return qf_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
@@ -447,7 +441,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
key,
)
# No target q values with CrossQ
# qf_state = cls.soft_update(tau, qf_state)

(actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
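
To make the joint forward pass described in the comment above concrete, here is a small illustrative sketch (with made-up shapes and a dummy critic, not the actual `qf_state.apply_fn` from this PR). Concatenating the current and next observations/actions lets the normalization layers inside the critic compute their statistics over the mixture of both distributions in a single pass; the resulting Q-values of shape `(n_critics, 2 * batch_size, 1)` are then split back into the current and next halves.

```python
import jax.numpy as jnp

# Hypothetical shapes: 256 transitions, 3-dim observations, 1-dim actions, 2 critics.
batch_size, n_critics = 256, 2
observations = jnp.zeros((batch_size, 3))
next_observations = jnp.ones((batch_size, 3))
actions = jnp.zeros((batch_size, 1))
next_state_actions = jnp.ones((batch_size, 1))

# Single joint forward pass: batch statistics are computed over the mixture of
# current and next inputs, so the "next" half is never out-of-distribution.
joint_obs = jnp.concatenate([observations, next_observations], axis=0)
joint_actions = jnp.concatenate([actions, next_state_actions], axis=0)
qf_input = jnp.concatenate([joint_obs, joint_actions], axis=1)

# Stand-in for the critic forward pass: a dummy per-critic readout that keeps
# the sketch runnable; shape is (n_critics, 2 * batch_size, 1) as in the comment.
q_values = jnp.stack([qf_input.sum(axis=1, keepdims=True) for _ in range(n_critics)])

# Split back into Q-values for (obs, actions) and (next_obs, next_state_actions).
current_q_values, next_q_values = jnp.split(q_values, 2, axis=1)
```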