From 7f5dd22e2c376a42c3ffc7ed73e9eeeddfdd3f87 Mon Sep 17 00:00:00 2001
From: Daniel Palenicek
Date: Mon, 25 Mar 2024 20:14:42 +0100
Subject: [PATCH 1/7] Implemented CrossQ

---
 sbx/crossq/batch_renorm.py | 188 +++++++++++++++++++++++++++++++++++++
 sbx/crossq/crossq.py       |  24 ++---
 sbx/crossq/policies.py     |  25 ++---
 3 files changed, 209 insertions(+), 28 deletions(-)
 create mode 100644 sbx/crossq/batch_renorm.py

diff --git a/sbx/crossq/batch_renorm.py b/sbx/crossq/batch_renorm.py
new file mode 100644
index 0000000..556a6ab
--- /dev/null
+++ b/sbx/crossq/batch_renorm.py
@@ -0,0 +1,188 @@
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+from flax.linen.module import Module, compact, merge_param  # pylint: disable=g-multiple-import
+from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
+from jax.nn import initializers
+
+PRNGKey = Any
+Array = Any
+Shape = Tuple[int, ...]
+Dtype = Any  # this could be a real type?
+Axes = Union[int, Sequence[int]]
+
+
+class BatchRenorm(Module):
+    """BatchRenorm Module (https://arxiv.org/pdf/1702.03275.pdf).
+    Adapted from flax.linen.normalization.BatchNorm
+
+    Usage Note:
+    If we define a model with BatchRenorm, for example::
+
+      BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32)
+
+    The initialized variables dict will contain in addition to a 'params'
+    collection a separate 'batch_stats' collection that will contain all the
+    running statistics for all the BatchRenorm layers in a model::
+
+      vars_initialized = BRN.init(key, x)  # {'params': ..., 'batch_stats': ...}
+
+    We then update the batch_stats during training by specifying that the
+    `batch_stats` collection is mutable in the `apply` method for our module.::
+
+      vars_in = {'params': params, 'batch_stats': old_batch_stats}
+      y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats'])
+      new_batch_stats = mutated_vars['batch_stats']
+
+    During eval we would define BRN with `use_running_average=True` and use the
+    batch_stats collection from training to set the statistics. In this case
+    we are not mutating the batch statistics collection, and needn't mark it
+    mutable::
+
+      vars_in = {'params': params, 'batch_stats': training_batch_stats}
+      y = BRN.apply(vars_in, x)
+
+    Attributes:
+      use_running_average: if True, the statistics stored in batch_stats will be
+        used. Else the running statistics will be first updated and then used to normalize.
+      axis: the feature or non-batch axis of the input.
+      momentum: decay rate for the exponential moving average of the batch
+        statistics.
+      epsilon: a small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      param_dtype: the dtype passed to parameter initializers (default: float32).
+      use_bias: if True, bias (beta) is added.
+      use_scale: if True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: initializer for bias, by default, zero.
+      scale_init: initializer for scale, by default, one.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See `jax.pmap` for a description of axis names (default: None).
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+        examples on the first two and last two devices. See `jax.lax.psum` for
+        more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    use_running_average: Optional[bool] = None
+    axis: int = -1
+    momentum: float = 0.99
+    epsilon: float = 0.001
+    warm_up_steps: int = 100_000
+    dtype: Optional[Dtype] = None
+    param_dtype: Dtype = jnp.float32
+    use_bias: bool = True
+    use_scale: bool = True
+    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
+    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+    axis_name: Optional[str] = None
+    axis_index_groups: Any = None
+    use_fast_variance: bool = True
+
+    @compact
+    def __call__(self, x, use_running_average: Optional[bool] = None):
+        """Normalizes the input using batch statistics.
+
+        NOTE:
+        During initialization (when `self.is_initializing()` is `True`) the running
+        average of the batch statistics will not be updated. Therefore, the inputs
+        fed during initialization don't need to match that of the actual input
+        distribution and the reduction axis (set with `axis_name`) does not have
+        to exist.
+
+        Args:
+          x: the input to be normalized.
+          use_running_average: if true, the statistics stored in batch_stats will be
+            used instead of computing the batch statistics on the input.
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+
+        use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average)
+        feature_axes = _canonicalize_axes(x.ndim, self.axis)
+        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
+        feature_shape = [x.shape[ax] for ax in feature_axes]
+
+        ra_mean = self.variable(
+            "batch_stats",
+            "mean",
+            lambda s: jnp.zeros(s, jnp.float32),
+            feature_shape,
+        )
+        ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)
+
+        r_max = self.variable(
+            "batch_stats",
+            "r_max",
+            lambda s: s,
+            3,
+        )
+        d_max = self.variable(
+            "batch_stats",
+            "d_max",
+            lambda s: s,
+            5,
+        )
+        steps = self.variable(
+            "batch_stats",
+            "steps",
+            lambda s: s,
+            0,
+        )
+
+        if use_running_average:
+            mean, var = ra_mean.value, ra_var.value
+            custom_mean = mean
+            custom_var = var
+        else:
+            mean, var = _compute_stats(
+                x,
+                reduction_axes,
+                dtype=self.dtype,
+                axis_name=self.axis_name if not self.is_initializing() else None,
+                axis_index_groups=self.axis_index_groups,
+                use_fast_variance=self.use_fast_variance,
+            )
+            custom_mean = mean
+            custom_var = var
+            if not self.is_initializing():
+                r = 1
+                d = 0
+                std = jnp.sqrt(var + self.epsilon)
+                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
+                r = jax.lax.stop_gradient(std / ra_std)
+                r = jnp.clip(r, 1 / r_max.value, r_max.value)
+                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
+                d = jnp.clip(d, -d_max.value, d_max.value)
+                tmp_var = var / (r**2)
+                tmp_mean = mean - d * jnp.sqrt(custom_var) / r
+
+                is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
+                custom_var = is_warmed_up * tmp_var + (1.0 - is_warmed_up) * custom_var
+                custom_mean = is_warmed_up * tmp_mean + (1.0 - is_warmed_up) * custom_mean
+
+                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
+                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+                steps.value += 1
+
+        return _normalize(
+            self,
+            x,
+            custom_mean,
+            custom_var,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_bias,
+            self.use_scale,
+            self.bias_init,
+            self.scale_init,
+        )
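For readers less familiar with the Flax idiom above, here is a minimal, self-contained sketch of driving the new layer exactly as its docstring describes. The array shape and RNG seed are illustrative; the import path matches this commit (a later commit in this series moves the layer to `sbx.common.jax_layers`):

```python
import jax
import jax.numpy as jnp

from sbx.crossq.batch_renorm import BatchRenorm  # renamed to sbx.common.jax_layers in PATCH 4/7

x = jnp.ones((256, 17))  # hypothetical batch of feature vectors

# Training-mode module: compute batch statistics and update the running ones.
brn_train = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001)
variables = brn_train.init(jax.random.PRNGKey(0), x)  # {'params': ..., 'batch_stats': ...}

# The `batch_stats` collection is mutated during training, so mark it mutable.
y, mutated = brn_train.apply(variables, x, mutable=["batch_stats"])
variables = {"params": variables["params"], "batch_stats": mutated["batch_stats"]}

# Evaluation-mode module: normalize with the stored running statistics only.
brn_eval = BatchRenorm(use_running_average=True, momentum=0.99, epsilon=0.001)
y_eval = brn_eval.apply(variables, x)
```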
diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 31c0358..20396dc 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -53,16 +53,15 @@ def __init__(
         self,
         policy,
         env: Union[GymEnv, str],
-        learning_rate: Union[float, Schedule] = 3e-4,
+        learning_rate: Union[float, Schedule] = 1e-3,
         qf_learning_rate: Optional[float] = None,
         buffer_size: int = 1_000_000,  # 1e6
-        learning_starts: int = 100,
+        learning_starts: int = 5_000,
         batch_size: int = 256,
-        tau: float = 0.005,
         gamma: float = 0.99,
         train_freq: Union[int, Tuple[int, str]] = 1,
         gradient_steps: int = 1,
-        policy_delay: int = 1,
+        policy_delay: int = 3,
         action_noise: Optional[ActionNoise] = None,
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -85,7 +84,6 @@ def __init__(
             buffer_size=buffer_size,
             learning_starts=learning_starts,
             batch_size=batch_size,
-            tau=tau,
             gamma=gamma,
             train_freq=train_freq,
             gradient_steps=gradient_steps,
@@ -106,8 +104,11 @@ def __init__(
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
 
-        # if "optimizer_kwargs" not in self.policy_kwargs:
-        #     self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}
+        if "net_arch" not in self.policy_kwargs:
+            self.policy_kwargs["net_arch"] = {"pi": [256, 256], "qf": [2048, 2048]}
+
+        if "optimizer_kwargs" not in self.policy_kwargs:
+            self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}
 
         if _init_setup_model:
             self._setup_model()
@@ -256,15 +257,6 @@ def update_critic(
 
         ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
 
-        # qf_next_values = qf_state.apply_fn(
-        #     {"params": qf_state.params, "batch_stats": qf_state.batch_stats},
-        #     next_observations,
-        #     next_state_actions,
-        #     rngs={"dropout": dropout_key_target},
-        #     train=False,  # todo: concatenate with obs, use train=True in that case
-        # )
-
-        # TODO: concatenate obs/next obs
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 04e59d4..3ba2c18 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -12,6 +12,7 @@
 from sbx.common.distributions import TanhTransformedDistribution
 from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import BatchNormTrainState
+from sbx.crossq.batch_renorm import BatchRenorm
 
 tfp = tensorflow_probability.substrates.jax
 tfd = tfp.distributions
@@ -22,14 +23,14 @@ class Critic(nn.Module):
     use_layer_norm: bool = False
     use_batch_norm: bool = False
     dropout_rate: Optional[float] = None
-    batch_norm_momentum: float = 0.9
+    batch_norm_momentum: float = 0.99
 
     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
         x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
         if self.use_batch_norm:
-            x = nn.BatchNorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+            x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
 
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
@@ -39,7 +40,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
                 x = nn.LayerNorm()(x)
             x = nn.relu(x)
             if self.use_batch_norm:
-                x = nn.BatchNorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+                x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
 
         x = nn.Dense(1)(x)
         return x
@@ -49,7 +50,7 @@ class VectorCritic(nn.Module):
     net_arch: Sequence[int]
     use_layer_norm: bool = False
     use_batch_norm: bool = False
-    batch_norm_momentum: float = 0.9
+    batch_norm_momentum: float = 0.99
     dropout_rate: Optional[float] = None
     n_critics: int = 2
 
@@ -81,7 +82,7 @@ class Actor(nn.Module):
     log_std_min: float = -20
     log_std_max: float = 2
     use_batch_norm: bool = False
-    batch_norm_momentum: float = 0.9
+    batch_norm_momentum: float = 0.99
 
     def get_std(self):
         # Make it work with gSDE
@@ -91,16 +92,16 @@ def get_std(self):
     def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  # type: ignore[name-defined]
         x = Flatten()(x)
         if self.use_batch_norm:
-            x = nn.BatchNorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+            x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
         else:
             # Create dummy batchstats
-            nn.BatchNorm(use_running_average=not train)(x)
+            BatchRenorm(use_running_average=not train)(x)
 
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
             x = nn.relu(x)
         if self.use_batch_norm:
-            x = nn.BatchNorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+            x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
 
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
@@ -123,8 +124,8 @@ def __init__(
         dropout_rate: float = 0.0,
         layer_norm: bool = False,
         batch_norm: bool = True,  # for critic
-        batch_norm_actor: bool = False,
-        batch_norm_momentum: float = 0.9,
+        batch_norm_actor: bool = True,
+        batch_norm_momentum: float = 0.99,
         # activation_fn: Type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         # Note: most gSDE parameters are not used
@@ -162,8 +163,8 @@ def __init__(
                 self.net_arch_qf = net_arch["qf"]
         else:
             self.net_arch_pi = [256, 256]
-            # self.net_arch_qf = [2048, 2048]
-            self.net_arch_qf = [256, 256]
+            self.net_arch_qf = [2048, 2048]
+            # self.net_arch_qf = [256, 256]
 
         self.n_critics = n_critics
         self.use_sde = use_sde
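Taken together, the hunks above amount to the following end-user behavior, shown as a sketch using only names and defaults visible in this commit (the environment and timestep budget are illustrative, and later commits in this series revisit some of these defaults):

```python
import gymnasium as gym

from sbx import CrossQ

# CrossQ defaults as of this commit: lr=1e-3, policy_delay=3, no tau/target
# network, and a wide [2048, 2048] critic injected via policy_kwargs.
env = gym.make("Pendulum-v1")
model = CrossQ("MlpPolicy", env, learning_rate=1e-3, policy_delay=3, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
```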
From e73cf2e730e9c4f06722fd8308e3f9461d8ccdc5 Mon Sep 17 00:00:00 2001
From: Daniel Palenicek
Date: Tue, 26 Mar 2024 19:12:18 +0100
Subject: [PATCH 2/7] Added CrossQ to README

---
 README.md | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index ec3f403..bde9c7b 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ Implemented algorithms:
 - [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
+- [CrossQ](https://arxiv.org/pdf/1902.05605.pdf)
 
 ### Install using pip
 
@@ -36,7 +37,7 @@ pip install sbx-rl
 ```python
 import gymnasium as gym
 
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
 
 env = gym.make("Pendulum-v1", render_mode="human")
 
@@ -61,7 +62,7 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo).
 import rl_zoo3
 import rl_zoo3.train
 from rl_zoo3.train import train
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
 
 rl_zoo3.ALGOS["ddpg"] = DDPG
 rl_zoo3.ALGOS["dqn"] = DQN
@@ -70,6 +71,7 @@ rl_zoo3.ALGOS["sac"] = SAC
 rl_zoo3.ALGOS["ppo"] = PPO
 rl_zoo3.ALGOS["td3"] = TD3
 rl_zoo3.ALGOS["tqc"] = TQC
+rl_zoo3.ALGOS["crossq"] = CrossQ
 rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
 rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
@@ -89,7 +91,7 @@ The same goes for the enjoy script:
 import rl_zoo3
 import rl_zoo3.enjoy
 from rl_zoo3.enjoy import enjoy
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
 
 rl_zoo3.ALGOS["ddpg"] = DDPG
 rl_zoo3.ALGOS["dqn"] = DQN
@@ -98,6 +100,7 @@ rl_zoo3.ALGOS["sac"] = SAC
 rl_zoo3.ALGOS["ppo"] = PPO
 rl_zoo3.ALGOS["td3"] = TD3
 rl_zoo3.ALGOS["tqc"] = TQC
+rl_zoo3.ALGOS["crossq"] = CrossQ
 rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
 rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
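The README hunks above stop at the `ALGOS` bookkeeping; in the full scripts, control is then handed to the imported entry point. A sketch of the remaining glue, where the `__main__` guard is standard RL Zoo usage and the exact CLI invocation is an assumption rather than something shown in this diff:

```python
# Tail of the README's train-script sketch (assumed, not shown in the hunk above):
if __name__ == "__main__":
    # train() parses CLI arguments; an assumed invocation would be
    # `python train.py --algo crossq --env Pendulum-v1`
    train()
```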
From 74af2eaf7021ca49f0f6f45b49a5f824ad1f8034 Mon Sep 17 00:00:00 2001
From: Daniel Palenicek
Date: Tue, 26 Mar 2024 19:12:43 +0100
Subject: [PATCH 3/7] clean up and comments

---
 sbx/crossq/crossq.py   | 12 +++---------
 sbx/crossq/policies.py | 13 +++----------
 2 files changed, 6 insertions(+), 19 deletions(-)

diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 20396dc..208747e 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -260,8 +260,9 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-            # Concatenate obs/next_obs to have only one forward pass
-            # shape is (n_critics, 2 * batch_size, 1)
+            # Concatenate obs/next_obs to have only one forward pass shape is (n_critics, 2 * batch_size, 1)
+            # This directly calculates the batch statistics for the mixture distribution of
+            # state and next_state and actions and next_state_actions
             q_values, state_updates = qf_state.apply_fn(
                 {"params": params, "batch_stats": batch_stats},
                 jnp.concatenate([observations, next_observations], axis=0),
@@ -339,12 +340,6 @@ def actor_loss(
 
         return actor_state, qf_state, actor_loss_value, key, entropy
 
-    # @staticmethod
-    # @jax.jit
-    # def soft_update(tau: float, qf_state: BatchNormTrainState):
-    #     qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
-    #     return qf_state
-
     @staticmethod
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
@@ -439,7 +434,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
                 key,
             )
             # No target q values with CrossQ
-            # qf_state = cls.soft_update(tau, qf_state)
 
             (actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
                 (policy_delay_offset + i) % policy_delay == 0,

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 3ba2c18..441bcf1 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -21,7 +21,7 @@
 class Critic(nn.Module):
     net_arch: Sequence[int]
     use_layer_norm: bool = False
-    use_batch_norm: bool = False
+    use_batch_norm: bool = True
     dropout_rate: Optional[float] = None
     batch_norm_momentum: float = 0.99
 
@@ -49,7 +49,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
 class VectorCritic(nn.Module):
     net_arch: Sequence[int]
     use_layer_norm: bool = False
-    use_batch_norm: bool = False
+    use_batch_norm: bool = True
     batch_norm_momentum: float = 0.99
     dropout_rate: Optional[float] = None
     n_critics: int = 2
 
@@ -81,7 +81,7 @@ class Actor(nn.Module):
     action_dim: int
     log_std_min: float = -20
     log_std_max: float = 2
-    use_batch_norm: bool = False
+    use_batch_norm: bool = True
     batch_norm_momentum: float = 0.99
 
     def get_std(self):
@@ -126,7 +126,6 @@ def __init__(
         batch_norm: bool = True,  # for critic
         batch_norm_actor: bool = True,
         batch_norm_momentum: float = 0.99,
-        # activation_fn: Type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         # Note: most gSDE parameters are not used
         # this is to keep API consistent with SB3
@@ -229,12 +228,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
             apply_fn=self.qf.apply,
             params=qf_params["params"],
             batch_stats=qf_params["batch_stats"],
-            # target_params=self.qf.init(
-            #     {"params": qf_key, "dropout": dropout_key, "batch_stats": bn_key},
-            #     obs,
-            #     action,
-            #     train=False,
-            # ),
            tx=self.optimizer_class(
                 learning_rate=qf_learning_rate,  # type: ignore[call-arg]
                 **self.optimizer_kwargs,
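The comment introduced above ("This directly calculates the batch statistics for the mixture distribution...") is the heart of CrossQ's critic update. A standalone sketch of that joint pass follows; the names mirror the diff, while the `mutable` and `train` arguments are assumptions about the parts of the function body the hunks do not show:

```python
import jax.numpy as jnp


def joint_q_forward(qf_state, params, batch_stats, observations, actions,
                    next_observations, next_state_actions, dropout_key):
    # Single forward pass over the concatenated batch, so the normalization
    # layers see the mixture of (obs, actions) and (next_obs, next_actions).
    q_values, state_updates = qf_state.apply_fn(
        {"params": params, "batch_stats": batch_stats},
        jnp.concatenate([observations, next_observations], axis=0),
        jnp.concatenate([actions, next_state_actions], axis=0),
        rngs={"dropout": dropout_key},
        mutable=["batch_stats"],
        train=True,
    )
    # Output shape (n_critics, 2 * batch_size, 1): first half is Q(s, a),
    # second half is Q(s', a') used for the TD target.
    current_q, next_q = jnp.split(q_values, 2, axis=1)
    return current_q, next_q, state_updates["batch_stats"]
```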
From 13565d7750f9e02cada4c79a662cd9ca19fa77c9 Mon Sep 17 00:00:00 2001
From: Daniel Palenicek
Date: Wed, 27 Mar 2024 18:18:27 +0100
Subject: [PATCH 4/7] refactored and added comments

---
 .../batch_renorm.py => common/jax_layers.py}  |  5 ++++
 sbx/crossq/crossq.py                          | 27 ++++++++++++-------
 sbx/crossq/policies.py                        | 12 ++++++---
 3 files changed, 31 insertions(+), 13 deletions(-)
 rename sbx/{crossq/batch_renorm.py => common/jax_layers.py} (95%)

diff --git a/sbx/crossq/batch_renorm.py b/sbx/common/jax_layers.py
similarity index 95%
rename from sbx/crossq/batch_renorm.py
rename to sbx/common/jax_layers.py
index 556a6ab..838c8fb 100644
--- a/sbx/crossq/batch_renorm.py
+++ b/sbx/common/jax_layers.py
@@ -17,6 +17,11 @@ class BatchRenorm(Module):
     """BatchRenorm Module (https://arxiv.org/pdf/1702.03275.pdf).
     Adapted from flax.linen.normalization.BatchNorm
 
+    BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
+    BatchRenorm always uses the running statistics for normalizing the batches.
+    This makes it less prone to suffer from "outlier" batches that can happen
+    during very long training runs and, therefore, more robust for long training runs.
+
     Usage Note:
     If we define a model with BatchRenorm, for example::

diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 208747e..24d61d4 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -56,7 +56,7 @@ def __init__(
         learning_rate: Union[float, Schedule] = 1e-3,
         qf_learning_rate: Optional[float] = None,
         buffer_size: int = 1_000_000,  # 1e6
-        learning_starts: int = 5_000,
+        learning_starts: int = 100,
         batch_size: int = 256,
         gamma: float = 0.99,
         train_freq: Union[int, Tuple[int, str]] = 1,
@@ -104,12 +104,6 @@ def __init__(
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
 
-        if "net_arch" not in self.policy_kwargs:
-            self.policy_kwargs["net_arch"] = {"pi": [256, 256], "qf": [2048, 2048]}
-
-        if "optimizer_kwargs" not in self.policy_kwargs:
-            self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}
-
         if _init_setup_model:
             self._setup_model()
 
@@ -260,9 +254,22 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-            # Concatenate obs/next_obs to have only one forward pass shape is (n_critics, 2 * batch_size, 1)
-            # This directly calculates the batch statistics for the mixture distribution of
-            # state and next_state and actions and next_state_actions
+
+            # Joint foward pass of obs/next_obs and actions/next_state_actions to have only
+            # one forward pass with shape (n_critics, 2 * batch_size, 1).
+            #
+            # This has two reasons:
+            # 1. According to the paper obs/actions and next_obs/next_state_actions are differently
+            #    distributed, which is the reason why "naively" applying Batch Normalization in SAC fails.
+            #    The batch statistics have to instead be calculated for the mixture distribution of obs/next_obs
+            #    and actions/next_state_actions. Otherwise, next_obs/next_state_actions are perceived as
+            #    out-of-distribution to the Batch Normalization layer, since running statistics are only
+            #    Polyak-averaged from the live network and have never seen the next batch, which is known to be unstable.
+            #    Without target networks, the joint forward pass is a simple solution to caluclate
+            #    the joint batch statistics directly with a single forward pass.
+            #
+            # 2. From a computational perspective a single forward pass is simply more efficient than
+            #    two sequential forward passes.
             q_values, state_updates = qf_state.apply_fn(
                 {"params": params, "batch_stats": batch_stats},
                 jnp.concatenate([observations, next_observations], axis=0),

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 441bcf1..2fefa7e 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -12,7 +12,7 @@
 from sbx.common.distributions import TanhTransformedDistribution
 from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import BatchNormTrainState
-from sbx.crossq.batch_renorm import BatchRenorm
+from sbx.common.jax_layers import BatchRenorm
 
 tfp = tensorflow_probability.substrates.jax
 tfd = tfp.distributions
@@ -136,7 +136,10 @@ def __init__(
         features_extractor_kwargs: Optional[Dict[str, Any]] = None,
         normalize_images: bool = True,
         optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        # Note: the default value for b1 is 0.9 in Adam.
+        # b1=0.5 is used in the original CrossQ implementation,
+        # but shows only little overall improvement.
+        optimizer_kwargs: Dict[str, Any] = {"b1": 0.5},
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -154,6 +157,7 @@ def __init__(
         self.batch_norm = batch_norm
         self.batch_norm_momentum = batch_norm_momentum
         self.batch_norm_actor = batch_norm_actor
+
         if net_arch is not None:
             if isinstance(net_arch, list):
                 self.net_arch_pi = self.net_arch_qf = net_arch
@@ -162,8 +166,10 @@ def __init__(
                 self.net_arch_qf = net_arch["qf"]
         else:
             self.net_arch_pi = [256, 256]
+            # While CrossQ already works well with a [256,256] critic network,
+            # the authors found that a much wider network significantly improves performance.
             self.net_arch_qf = [2048, 2048]
-            # self.net_arch_qf = [256, 256]
+
         self.n_critics = n_critics
         self.use_sde = use_sde

From c6e75da7be7319c2af11f5229dea97fe2fae7b9f Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 28 Mar 2024 13:53:36 +0100
Subject: [PATCH 5/7] Update doc

---
 Makefile  |  8 ++++----
 README.md |  2 +-
 setup.py  | 12 ++++--------
 3 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index 2240fdc..0177d5a 100644
--- a/Makefile
+++ b/Makefile
@@ -12,19 +12,19 @@ type: mypy
 lint:
 	# stop the build if there are Python syntax errors or undefined names
 	# see https://www.flake8rules.com/
-	ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
+	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
 	# exit-zero treats all errors as warnings.
-	ruff ${LINT_PATHS} --exit-zero
+	ruff check ${LINT_PATHS} --exit-zero
 
 format:
 	# Sort imports
-	ruff --select I ${LINT_PATHS} --fix
+	ruff check --select I ${LINT_PATHS} --fix
 	# Reformat using black
 	black ${LINT_PATHS}
 
 check-codestyle:
 	# Sort imports
-	ruff --select I ${LINT_PATHS}
+	ruff check --select I ${LINT_PATHS}
 	# Reformat using black
 	black --check ${LINT_PATHS}

diff --git a/README.md b/README.md
index bde9c7b..0fb5721 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ Implemented algorithms:
 - [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
-- [CrossQ](https://arxiv.org/pdf/1902.05605.pdf)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)

diff --git a/setup.py b/setup.py
index 1c519cb..5970f32 100644
--- a/setup.py
+++ b/setup.py
@@ -22,11 +22,12 @@
 - [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
 
 ## Example
 
 ```python
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
 
 model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
 model.learn(total_timesteps=10_000, progress_bar=True)
@@ -60,9 +61,9 @@
             # Type check
             "mypy",
             # Lint code
-            "ruff",
+            "ruff>=0.3.1",
             # Reformat
-            "black",
+            "black>=24.2.0,<25",
         ],
     },
     description="Jax version of Stable Baselines, implementations of reinforcement learning algorithms.",
@@ -85,8 +86,3 @@
         "Programming Language :: Python :: 3.11",
     ],
 )
-
-# python setup.py sdist
-# python setup.py bdist_wheel
-# twine upload --repository-url https://test.pypi.org/legacy/ dist/*
-# twine upload dist/*
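PATCH 4/7 above bakes `{"b1": 0.5}` into the policy defaults. In optax terms this amounts to the following sketch, using only values taken from the diffs (the learning rate shown is CrossQ's default from PATCH 1/7):

```python
import optax

# Adam's default first-moment decay is b1=0.9; CrossQ follows the original
# implementation and lowers it to 0.5.
qf_optimizer = optax.adam(learning_rate=1e-3, b1=0.5)
```

A user can still override this per model, e.g. `policy_kwargs={"optimizer_kwargs": {"b1": 0.9}}`, to recover plain Adam.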
From f2d4e27d3cd434cf83b97c35f85d6a8d617bd8a0 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 28 Mar 2024 13:58:57 +0100
Subject: [PATCH 6/7] Cleanup CrossQ and BatchRenorm

---
 sbx/common/jax_layers.py | 41 +++++++++++++++++++++++++++------------
 sbx/crossq/crossq.py     |  8 ++++----
 sbx/crossq/policies.py   | 24 ++++++++++++++---------
 3 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py
index 838c8fb..5733d2c 100644
--- a/sbx/common/jax_layers.py
+++ b/sbx/common/jax_layers.py
@@ -2,7 +2,7 @@
 
 import jax
 import jax.numpy as jnp
-from flax.linen.module import Module, compact, merge_param  # pylint: disable=g-multiple-import
+from flax.linen.module import Module, compact, merge_param
 from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
 from jax.nn import initializers
 
@@ -14,14 +14,16 @@
 
 
 class BatchRenorm(Module):
-    """BatchRenorm Module (https://arxiv.org/pdf/1702.03275.pdf).
+    """BatchRenorm Module (https://arxiv.org/abs/1702.03275).
     Adapted from flax.linen.normalization.BatchNorm
 
     BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
-    BatchRenorm always uses the running statistics for normalizing the batches.
+    BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
     This makes it less prone to suffer from "outlier" batches that can happen
     during very long training runs and, therefore, more robust for long training runs.
 
+    During the warmup phase, it behaves exactly like a BatchNorm layer.
+
     Usage Note:
     If we define a model with BatchRenorm, for example::
 
@@ -87,7 +89,10 @@ class BatchRenorm(Module):
     scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
     axis_name: Optional[str] = None
     axis_index_groups: Any = None
-    use_fast_variance: bool = True
+    # This parameter was added in flax.linen 0.7.2 (08/2023)
+    # commented out to be compatible with a wider range of jax versions
+    # TODO: re-activate in some months (04/2024)
+    # use_fast_variance: bool = True
 
     @compact
     def __call__(self, x, use_running_average: Optional[bool] = None):
@@ -152,28 +157,38 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
                 dtype=self.dtype,
                 axis_name=self.axis_name if not self.is_initializing() else None,
                 axis_index_groups=self.axis_index_groups,
-                use_fast_variance=self.use_fast_variance,
+                # use_fast_variance=self.use_fast_variance,
             )
             custom_mean = mean
             custom_var = var
             if not self.is_initializing():
-                r = 1
-                d = 0
+                r = jnp.array(1.0)
+                d = jnp.array(0.0)
                 std = jnp.sqrt(var + self.epsilon)
                 ra_std = jnp.sqrt(ra_var.value + self.epsilon)
+                # scale
                 r = jax.lax.stop_gradient(std / ra_std)
                 r = jnp.clip(r, 1 / r_max.value, r_max.value)
+                # bias
                 d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
                 d = jnp.clip(d, -d_max.value, d_max.value)
-                tmp_var = var / (r**2)
-                tmp_mean = mean - d * jnp.sqrt(custom_var) / r
+                # BatchNorm normalization, using minibatch stats and running average stats
+                # Because we use _normalize, this is equivalent to
+                # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
+                # where sigma = sqrt(var)
+                affine_mean = mean - d * jnp.sqrt(var) / r
+                affine_var = var / (r**2)
+
+                # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
+                # the constraints are linearly relaxed to r_max/d_max over 40k steps
+                # Here we only have a warmup phase
                 is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
-                custom_var = is_warmed_up * tmp_var + (1.0 - is_warmed_up) * custom_var
-                custom_mean = is_warmed_up * tmp_mean + (1.0 - is_warmed_up) * custom_mean
+                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * custom_var
+                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * custom_mean
 
-                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
-                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
+                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * mean
+                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * var
                 steps.value += 1
 
         return _normalize(
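A standalone numeric illustration of the `r`/`d` correction just cleaned up above (not sbx code; the example statistics are made up). The batch statistics are corrected towards the running statistics through the clipped scale `r` and bias `d`, and the affine re-parameterization lets `_normalize` apply the correction as an ordinary mean/variance:

```python
import jax
import jax.numpy as jnp

# Batch stats vs. running stats for a single feature.
mean, var = jnp.array([0.5]), jnp.array([4.0])
ra_mean, ra_var = jnp.array([0.0]), jnp.array([1.0])
epsilon, r_max, d_max = 0.001, 3.0, 5.0

std = jnp.sqrt(var + epsilon)
ra_std = jnp.sqrt(ra_var + epsilon)
r = jnp.clip(jax.lax.stop_gradient(std / ra_std), 1 / r_max, r_max)           # scale
d = jnp.clip(jax.lax.stop_gradient((mean - ra_mean) / ra_std), -d_max, d_max)  # bias

# (x - mean) / std * r + d  ==  (x - affine_mean) / sqrt(affine_var)  (up to epsilon),
# so normalizing with (affine_mean, affine_var) reproduces the renorm output.
affine_mean = mean - d * jnp.sqrt(var) / r
affine_var = var / (r**2)
```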
diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 24d61d4..52d040c 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -254,18 +254,18 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-
-            # Joint foward pass of obs/next_obs and actions/next_state_actions to have only
+
+            # Joint forward pass of obs/next_obs and actions/next_state_actions to have only
             # one forward pass with shape (n_critics, 2 * batch_size, 1).
             #
             # This has two reasons:
             # 1. According to the paper obs/actions and next_obs/next_state_actions are differently
             #    distributed, which is the reason why "naively" applying Batch Normalization in SAC fails.
-            #    The batch statistics have to instead be calculated for the mixture distribution of obs/next_obs 
+            #    The batch statistics have to instead be calculated for the mixture distribution of obs/next_obs
             #    and actions/next_state_actions. Otherwise, next_obs/next_state_actions are perceived as
             #    out-of-distribution to the Batch Normalization layer, since running statistics are only
             #    Polyak-averaged from the live network and have never seen the next batch, which is known to be unstable.
-            #    Without target networks, the joint forward pass is a simple solution to caluclate
+            #    Without target networks, the joint forward pass is a simple solution to calculate
             #    the joint batch statistics directly with a single forward pass.
             #
             # 2. From a computational perspective a single forward pass is simply more efficient than

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 2fefa7e..93054d8 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -10,9 +10,9 @@
 from stable_baselines3.common.type_aliases import Schedule
 
 from sbx.common.distributions import TanhTransformedDistribution
+from sbx.common.jax_layers import BatchRenorm
 from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import BatchNormTrainState
-from sbx.common.jax_layers import BatchRenorm
 
 tfp = tensorflow_probability.substrates.jax
 tfd = tfp.distributions
@@ -136,13 +136,18 @@ def __init__(
         features_extractor_kwargs: Optional[Dict[str, Any]] = None,
         normalize_images: bool = True,
         optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
-        # Note: the default value for b1 is 0.9 in Adam.
-        # b1=0.5 is used in the original CrossQ implementation,
-        # but shows only little overall improvement.
-        optimizer_kwargs: Dict[str, Any] = {"b1": 0.5},
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
+        if optimizer_kwargs is None:
+            # Note: the default value for b1 is 0.9 in Adam.
+            # b1=0.5 is used in the original CrossQ implementation,
+            # but shows only little overall improvement.
+            optimizer_kwargs = {}
+            if optimizer_class in [optax.adam, optax.adamw]:
+                optimizer_kwargs["b1"] = 0.5
+
         super().__init__(
             observation_space,
             action_space,
@@ -157,7 +162,7 @@ def __init__(
         self.batch_norm = batch_norm
         self.batch_norm_momentum = batch_norm_momentum
         self.batch_norm_actor = batch_norm_actor
-
+
         if net_arch is not None:
             if isinstance(net_arch, list):
                 self.net_arch_pi = self.net_arch_qf = net_arch
@@ -166,9 +171,10 @@ def __init__(
                 self.net_arch_qf = net_arch["qf"]
         else:
             self.net_arch_pi = [256, 256]
-            # While CrossQ already works well with a [256,256] critic network,
-            # the authors found that a much wider network significantly improves performance.
-            self.net_arch_qf = [2048, 2048]
+            # While CrossQ already works with a [256,256] critic network,
+            # the authors found that a wider network significantly improves performance.
+            # We use a slightly smaller net for faster computation, [1024, 1024] instead of [2048, 2048] in the paper.
+            self.net_arch_qf = [1024, 1024]
 
         self.n_critics = n_critics
         self.use_sde = use_sde

From ddc6c901efd71519cba3ddba2e91cfa2d0a1e4ab Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 28 Mar 2024 14:37:44 +0100
Subject: [PATCH 7/7] Update tests

---
 sbx/crossq/policies.py | 17 ++++++++++--
 tests/test_run.py      | 61 ++++++++++++++++++++++++++++++++++--------
 2 files changed, 65 insertions(+), 13 deletions(-)

diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py
index 93054d8..a2c58d5 100644
--- a/sbx/crossq/policies.py
+++ b/sbx/crossq/policies.py
@@ -24,13 +24,21 @@ class Critic(nn.Module):
     use_batch_norm: bool = True
     dropout_rate: Optional[float] = None
     batch_norm_momentum: float = 0.99
+    renorm_warm_up_steps: int = 100_000
 
     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
         x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
         if self.use_batch_norm:
-            x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+            x = BatchRenorm(
+                use_running_average=not train,
+                momentum=self.batch_norm_momentum,
+                warm_up_steps=self.renorm_warm_up_steps,
+            )(x)
+        else:
+            # Create dummy batchstats
+            BatchRenorm(use_running_average=not train)(x)
 
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
@@ -83,6 +91,7 @@ class Actor(nn.Module):
     log_std_max: float = 2
     use_batch_norm: bool = True
     batch_norm_momentum: float = 0.99
+    renorm_warm_up_steps: int = 100_000
 
     def get_std(self):
         # Make it work with gSDE
@@ -92,7 +101,11 @@ def get_std(self):
     def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution:  # type: ignore[name-defined]
         x = Flatten()(x)
         if self.use_batch_norm:
-            x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+            x = BatchRenorm(
+                use_running_average=not train,
+                momentum=self.batch_norm_momentum,
+                warm_up_steps=self.renorm_warm_up_steps,
+            )(x)
         else:
             # Create dummy batchstats
             BatchRenorm(use_running_average=not train)(x)

diff --git a/tests/test_run.py b/tests/test_run.py
index 919b096..cfc7d01 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -10,6 +10,19 @@
 from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ
 
 
+def check_save_load(model, model_class, tmp_path):
+    # Test save/load
+    env = model.get_env()
+    obs = env.observation_space.sample()
+    model.save(tmp_path / "test_save.zip")
+    action_before = model.predict(obs, deterministic=True)[0]
+    # Check we have the same prediction
+    model = model_class.load(tmp_path / "test_save.zip")
+    action_after = model.predict(obs, deterministic=True)[0]
+    assert np.allclose(action_before, action_after)
+    return model
+
+
 def test_droq(tmp_path):
     model = DroQ(
         "MlpPolicy",
@@ -31,20 +44,18 @@ def test_droq(tmp_path):
     # Check that something was learned
     evaluate_policy(model, model.get_env(), reward_threshold=-800)
     model.save(tmp_path / "test_save.zip")
+
     env = model.get_env()
-    obs = env.observation_space.sample()
-    action_before = model.predict(obs, deterministic=True)[0]
+    model = check_save_load(model, DroQ, tmp_path)
     # Check we have the same performance
-    model = DroQ.load(tmp_path / "test_save.zip")
     evaluate_policy(model, env, reward_threshold=-800)
-    action_after = model.predict(obs, deterministic=True)[0]
-    assert np.allclose(action_before, action_after)
+
     # Continue training
     model.set_env(env, force_reset=False)
     model.learn(100, reset_num_timesteps=False)
 
 
-def test_tqc() -> None:
+def test_tqc(tmp_path) -> None:
     # Multi env
     train_env = make_vec_env("Pendulum-v1", n_envs=4)
     model = TQC(
@@ -58,10 +69,11 @@ def test_tqc() -> None:
         qf_learning_rate=1e-3,
     )
     model.learn(200)
+    check_save_load(model, TQC, tmp_path)
 
 
 @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ])
-def test_sac_td3(model_class) -> None:
+def test_sac_td3(tmp_path, model_class) -> None:
     model = model_class(
         "MlpPolicy",
         "Pendulum-v1",
@@ -70,11 +82,35 @@ def test_sac_td3(model_class) -> None:
         learning_rate=1e-3,
     )
     model.learn(110)
-    # model.learn(20_000, progress_bar=True)
+    check_save_load(model, model_class, tmp_path)
+
+
+@pytest.mark.parametrize("model_class", [SAC, CrossQ])
+def test_dropout(model_class):
+    kwargs = {}
+    # Try activating layer norm and dropout
+    policy_kwargs = dict(dropout_rate=0.01, net_arch=[64], layer_norm=True)
+    if model_class == CrossQ:
+        # Try deactivating batch norm
+        policy_kwargs["batch_norm"] = False
+        policy_kwargs["batch_norm_actor"] = False
+        kwargs["ent_coef"] = 0.01  # constant entropy coeff
+    elif model_class == SAC:
+        policy_kwargs["net_arch"] = dict(pi=[32], qf=[16])
+
+    model = model_class(
+        "MlpPolicy",
+        "Pendulum-v1",
+        verbose=1,
+        gradient_steps=1,
+        learning_rate=1e-3,
+        policy_kwargs=policy_kwargs,
+        **kwargs,
+    )
+    model.learn(110)
 
 
 @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
-def test_ppo(env_id: str) -> None:
+def test_ppo(tmp_path, env_id: str) -> None:
     model = PPO(
         "MlpPolicy",
         env_id,
@@ -82,10 +118,12 @@ def test_ppo(env_id: str) -> None:
         n_steps=64,
         n_epochs=2,
     )
-    model.learn(128, progress_bar=True)
+    model.learn(256, progress_bar=True)
+
+    check_save_load(model, PPO, tmp_path)
 
 
-def test_dqn() -> None:
+def test_dqn(tmp_path) -> None:
     model = DQN(
         "MlpPolicy",
         "CartPole-v1",
@@ -94,6 +132,7 @@ def test_dqn() -> None:
         target_update_interval=10,
     )
     model.learn(128)
+    check_save_load(model, DQN, tmp_path)
 
 
 @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer])
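Outside of pytest, the round trip that the new `check_save_load` helper exercises looks as follows. This is a hedged sketch (the path, environment, and timestep budget are illustrative), using only the sbx/SB3 API calls that appear in the tests above:

```python
from pathlib import Path

import numpy as np
from sbx import CrossQ

model = CrossQ("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
model.learn(1_000)

obs = model.get_env().observation_space.sample()
action_before = model.predict(obs, deterministic=True)[0]

path = Path("crossq_pendulum.zip")
model.save(path)
loaded = CrossQ.load(path)

# Deterministic predictions should survive the save/load round trip.
assert np.allclose(action_before, loaded.predict(obs, deterministic=True)[0])
```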