Implemented CrossQ #36

Merged · 7 commits · Mar 29, 2024
8 changes: 4 additions & 4 deletions Makefile
@@ -12,19 +12,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
-ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
+ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
-ruff ${LINT_PATHS} --exit-zero
+ruff check ${LINT_PATHS} --exit-zero

format:
# Sort imports
-ruff --select I ${LINT_PATHS} --fix
+ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
-ruff --select I ${LINT_PATHS}
+ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

9 changes: 6 additions & 3 deletions README.md
@@ -17,6 +17,7 @@ Implemented algorithms:
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)


### Install using pip
@@ -36,7 +37,7 @@ pip install sbx-rl
```python
import gymnasium as gym

-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

@@ -61,7 +62,7 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -70,6 +71,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
+rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

@@ -89,7 +91,7 @@ The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -98,6 +100,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
+rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

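For reference, training with the newly added algorithm should then look like the other sbx algorithms. A minimal sketch, assuming the same SB3-style constructor used throughout sbx (the only values shown are the defaults introduced by this PR):

```python
import gymnasium as gym

from sbx import CrossQ

env = gym.make("Pendulum-v1")

# Defaults changed in this PR: learning_rate=1e-3, policy_delay=3,
# and no target network (the tau parameter was removed).
model = CrossQ("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```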
208 changes: 208 additions & 0 deletions sbx/common/jax_layers.py
@@ -0,0 +1,208 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
from jax.nn import initializers

PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]


class BatchRenorm(Module):
"""BatchRenorm Module (https://arxiv.org/abs/1702.03275).
Adapted from flax.linen.normalization.BatchNorm

BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
This makes it less prone to suffering from "outlier" batches that can happen
during very long training runs, and therefore more robust over such runs.

During the warmup phase, it behaves exactly like a BatchNorm layer.

Usage Note:
If we define a model with BatchRenorm, for example::

BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32)

The initialized variables dict will contain in addition to a 'params'
collection a separate 'batch_stats' collection that will contain all the
running statistics for all the BatchRenorm layers in a model::

vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...}

We then update the batch_stats during training by specifying that the
`batch_stats` collection is mutable in the `apply` method for our module::

vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']

During eval we would define BRN with `use_running_average=True` and use the
batch_stats collection from training to set the statistics. In this case
we are not mutating the batch statistics collection, and needn't mark it
mutable::

vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BRN.apply(vars_in, x)

Attributes:
use_running_average: if True, the statistics stored in batch_stats will be
used. Else the running statistics will be first updated and then used to normalize.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of the batch
statistics.
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""

use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.99
epsilon: float = 0.001
warm_up_steps: int = 100_000
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
# This parameter was added in flax.linen 0.7.2 (08/2023)
# commented out to be compatible with a wider range of jax versions
# TODO: re-activate in some months (04/2024)
# use_fast_variance: bool = True

@compact
def __call__(self, x, use_running_average: Optional[bool] = None):
"""Normalizes the input using batch statistics.

NOTE:
During initialization (when `self.is_initializing()` is `True`) the running
average of the batch statistics will not be updated. Therefore, the inputs
fed during initialization don't need to match that of the actual input
distribution and the reduction axis (set with `axis_name`) does not have
to exist.

Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.

Returns:
Normalized inputs (the same shape as inputs).
"""

use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
feature_shape = [x.shape[ax] for ax in feature_axes]

ra_mean = self.variable(
"batch_stats",
"mean",
lambda s: jnp.zeros(s, jnp.float32),
feature_shape,
)
ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)

r_max = self.variable(
"batch_stats",
"r_max",
lambda s: s,
3,
)
d_max = self.variable(
"batch_stats",
"d_max",
lambda s: s,
5,
)
steps = self.variable(
"batch_stats",
"steps",
lambda s: s,
0,
)

if use_running_average:
mean, var = ra_mean.value, ra_var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
# use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
if not self.is_initializing():
r = jnp.array(1.0)
d = jnp.array(0.0)
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var.value + self.epsilon)
# scale
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max.value, r_max.value)
# bias
d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
d = jnp.clip(d, -d_max.value, d_max.value)

# BatchNorm normalization, using minibatch stats and running average stats
# Because we use _normalize, this is equivalent to
# ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
# where sigma = sqrt(var)
affine_mean = mean - d * jnp.sqrt(var) / r
affine_var = var / (r**2)

# Note: in the original paper, after the warmup phase (batch norm phase of 5k steps),
# the constraints are linearly relaxed to r_max/d_max over 40k steps.
# Here we only have a warmup phase.
is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
Comment on lines +183 to +186
Owner:

Was there a reason to implement it that way? (simplicity?)

Also, how did you choose warm_up_steps: int = 100_000?
Because of the policy delay, renorm will be used only after 300_000 steps, is that intended?

Contributor (Author):

Yes, honestly simplicity. We did not play around with specific schedules for relaxation or such.
I also have not done super extensive testing on the exact number of warmup steps, so there might be room for improvement, but overall it seems to be pretty robust and it did not seem to matter much at which point you end up switching, as long as it was not too late. From our initial experiments we know that vanilla BN tended to become unstable for very long runs, but that everything up to somewhere around 700k steps was fine. So we simply picked a large enough warmup phase.

The policy delay does, in fact, extend the warmup phase, you are right there. I did not consider this, tbh. But I also don't think it makes a huge difference because, as I said, we found that in general training was not super sensitive to the exact duration of the warmup interval.

custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * custom_var
custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * custom_mean

ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * mean
ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * var
steps.value += 1

return _normalize(
self,
x,
custom_mean,
custom_var,
reduction_axes,
feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
self.use_bias,
self.use_scale,
self.bias_init,
self.scale_init,
)
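Note that the module above uses a hard warmup switch: `is_warmed_up` flips from 0 to 1 once `steps` reaches `warm_up_steps`. The original Batch Renorm paper instead relaxes the clipping limits gradually: a 5k-step BatchNorm phase, then linear relaxation over 40k steps up to the r_max=3 and d_max=5 values this module stores. A hypothetical sketch of that paper-style schedule, for comparison only (not code from this PR):

```python
import jax.numpy as jnp

def relaxed_limits(step, warmup_steps=5_000, relax_steps=40_000,
                   r_max_final=3.0, d_max_final=5.0):
    """Paper-style linear relaxation of the Batch Renorm clipping limits.

    During warmup, r_max=1 and d_max=0, so the clipping forces r=1 and d=0
    and the layer behaves exactly like BatchNorm; afterwards the limits
    grow linearly to their final values.
    """
    progress = jnp.clip((step - warmup_steps) / relax_steps, 0.0, 1.0)
    r_max = 1.0 + progress * (r_max_final - 1.0)
    d_max = progress * d_max_final
    return r_max, d_max
```

With such a schedule, the hard `is_warmed_up` gate would be replaced by feeding the relaxed limits into the existing `jnp.clip` calls; as the review thread above notes, the PR's simpler fixed warmup appears robust in practice.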
43 changes: 18 additions & 25 deletions sbx/crossq/crossq.py
@@ -53,16 +53,15 @@ def __init__(
self,
policy,
env: Union[GymEnv, str],
-learning_rate: Union[float, Schedule] = 3e-4,
+learning_rate: Union[float, Schedule] = 1e-3,
qf_learning_rate: Optional[float] = None,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
-tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
-policy_delay: int = 1,
+policy_delay: int = 3,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -85,7 +84,6 @@ def __init__(
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
-tau=tau,
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
@@ -106,9 +104,6 @@ def __init__(
self.policy_delay = policy_delay
self.ent_coef_init = ent_coef

# if "optimizer_kwargs" not in self.policy_kwargs:
# self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}

if _init_setup_model:
self._setup_model()

@@ -256,20 +251,25 @@ def update_critic(

ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})

-# qf_next_values = qf_state.apply_fn(
-#     {"params": qf_state.params, "batch_stats": qf_state.batch_stats},
-#     next_observations,
-#     next_state_actions,
-#     rngs={"dropout": dropout_key_target},
-#     train=False,  # todo: concatenate with obs, use train=True in that case
-# )

-# TODO: concatenate obs/next obs
def mse_loss(
params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
) -> Tuple[jax.Array, jax.Array]:
-# Concatenate obs/next_obs to have only one forward pass
-# shape is (n_critics, 2 * batch_size, 1)

+# Joint forward pass of obs/next_obs and actions/next_state_actions to have only
+# one forward pass with shape (n_critics, 2 * batch_size, 1).
+#
+# There are two reasons for this:
+# 1. According to the paper, obs/actions and next_obs/next_state_actions are differently
+# distributed, which is the reason why "naively" applying Batch Normalization in SAC fails.
+# The batch statistics instead have to be calculated for the mixture distribution of obs/next_obs
+# and actions/next_state_actions. Otherwise, next_obs/next_state_actions are perceived as
+# out-of-distribution by the Batch Normalization layer, since the running statistics are only
+# Polyak-averaged from the live network and have never seen the next batch, which is known to be unstable.
+# Without target networks, the joint forward pass is a simple way to calculate
+# the joint batch statistics directly in a single forward pass.
+#
+# 2. From a computational perspective, a single forward pass is simply more efficient than
+# two sequential forward passes.
q_values, state_updates = qf_state.apply_fn(
{"params": params, "batch_stats": batch_stats},
jnp.concatenate([observations, next_observations], axis=0),
@@ -347,12 +347,6 @@ def actor_loss(

return actor_state, qf_state, actor_loss_value, key, entropy

-# @staticmethod
-# @jax.jit
-# def soft_update(tau: float, qf_state: BatchNormTrainState):
-#     qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
-#     return qf_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
@@ -447,7 +441,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
key,
)
# No target q values with CrossQ
-# qf_state = cls.soft_update(tau, qf_state)

(actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
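The joint forward pass described in the comment block above is the core CrossQ mechanism. A condensed, hypothetical sketch of that pattern follows (the helper name is invented; the `rngs`, `mutable`, and `train` arguments are assumptions based on the surrounding diff):

```python
import jax.numpy as jnp

def joint_q_forward(qf_state, params, batch_stats, observations, actions,
                    next_observations, next_state_actions, dropout_key):
    """Single critic forward pass over current and next transitions.

    Concatenating the batches lets the BatchRenorm layers compute their
    statistics over the mixture of (obs, actions) and
    (next_obs, next_state_actions), removing the need for target networks.
    """
    q_all, state_updates = qf_state.apply_fn(
        {"params": params, "batch_stats": batch_stats},
        jnp.concatenate([observations, next_observations], axis=0),
        jnp.concatenate([actions, next_state_actions], axis=0),
        rngs={"dropout": dropout_key},
        mutable=["batch_stats"],
        train=True,
    )
    # q_all: (n_critics, 2 * batch_size, 1); split along the batch axis.
    current_q_values, next_q_values = jnp.split(q_all, 2, axis=1)
    return current_q_values, next_q_values, state_updates
```

Splitting along axis 1 recovers the current and next Q-values from the single pass, which realizes both the statistical fix from point 1 and the efficiency gain from point 2 of the comment above.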