Implemented CrossQ #36

Merged: 7 commits, Mar 29, 2024
Changes from 3 commits
9 changes: 6 additions & 3 deletions README.md
@@ -17,6 +17,7 @@ Implemented algorithms:
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [CrossQ](https://arxiv.org/pdf/1902.05605.pdf)


### Install using pip
@@ -36,7 +37,7 @@ pip install sbx-rl
```python
import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

@@ -61,7 +62,7 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -70,6 +71,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

@@ -89,7 +91,7 @@ The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
@@ -98,6 +100,7 @@ rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

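As a quick sanity check of the README change above, here is a minimal sketch of training the newly exported `CrossQ` class through the shared SB3/SBX API (illustrative only; all hyperparameters are left at their defaults):

```python
import gymnasium as gym

from sbx import CrossQ

env = gym.make("Pendulum-v1")

# CrossQ follows the usual SB3-style API exposed by SBX.
model = CrossQ("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

obs, _ = env.reset()
action, _ = model.predict(obs, deterministic=True)
```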
188 changes: 188 additions & 0 deletions sbx/crossq/batch_renorm.py
@@ -0,0 +1,188 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param # pylint: disable=g-multiple-import
from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
from jax.nn import initializers

PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]


class BatchRenorm(Module):
"""BatchRenorm Module (https://arxiv.org/pdf/1702.03275.pdf).
Adapted from flax.linen.normalization.BatchNorm

Usage Note:
If we define a model with BatchRenorm, for example::

BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32)

The initialized variables dict will contain, in addition to a 'params'
collection, a separate 'batch_stats' collection that will contain all the
running statistics for all the BatchRenorm layers in a model::

vars_initialized = BRN.init(key, x) # {'params': ..., 'batch_stats': ...}

We then update the batch_stats during training by specifying that the
`batch_stats` collection is mutable in the `apply` method for our module::

vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']

During eval we would define BRN with `use_running_average=True` and use the
batch_stats collection from training to set the statistics. In this case
we are not mutating the batch statistics collection, and needn't mark it
mutable::

vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BRN.apply(vars_in, x)

Attributes:
use_running_average: if True, the statistics stored in batch_stats will be
used to normalize. Else the batch statistics of the input are used to normalize,
and the running statistics are updated.
axis: the feature or non-batch axis of the input.
momentum: decay rate for the exponential moving average of the batch
statistics.
epsilon: a small float added to variance to avoid dividing by zero.
warm_up_steps: number of update steps during which the renorm correction is
skipped (plain batch statistics are used) while the running statistics warm up.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
axis_name: the axis name used to combine batch statistics from multiple
devices. See `jax.pmap` for a description of axis names (default: None).
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""

use_running_average: Optional[bool] = None
axis: int = -1
momentum: float = 0.99
epsilon: float = 0.001
warm_up_steps: int = 100_000
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True

@compact
def __call__(self, x, use_running_average: Optional[bool] = None):
"""Normalizes the input using batch statistics.

NOTE:
During initialization (when `self.is_initializing()` is `True`) the running
average of the batch statistics will not be updated. Therefore, the inputs
fed during initialization don't need to match that of the actual input
distribution and the reduction axis (set with `axis_name`) does not have
to exist.

Args:
x: the input to be normalized.
use_running_average: if true, the statistics stored in batch_stats will be
used instead of computing the batch statistics on the input.

Returns:
Normalized inputs (the same shape as inputs).
"""

use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
feature_shape = [x.shape[ax] for ax in feature_axes]

ra_mean = self.variable(
"batch_stats",
"mean",
lambda s: jnp.zeros(s, jnp.float32),
feature_shape,
)
ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)

r_max = self.variable(
"batch_stats",
"r_max",
lambda s: s,
3,
)
d_max = self.variable(
"batch_stats",
"d_max",
lambda s: s,
5,
)
steps = self.variable(
"batch_stats",
"steps",
lambda s: s,
0,
)

if use_running_average:
mean, var = ra_mean.value, ra_var.value
custom_mean = mean
custom_var = var
else:
mean, var = _compute_stats(
x,
reduction_axes,
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)
custom_mean = mean
custom_var = var
if not self.is_initializing():
r = 1
d = 0
# Batch Renormalization correction terms r and d (https://arxiv.org/abs/1702.03275),
# computed from the ratio/offset between batch and running statistics (clipped, no gradient flow).
std = jnp.sqrt(var + self.epsilon)
ra_std = jnp.sqrt(ra_var.value + self.epsilon)
r = jax.lax.stop_gradient(std / ra_std)
r = jnp.clip(r, 1 / r_max.value, r_max.value)
d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
d = jnp.clip(d, -d_max.value, d_max.value)
tmp_var = var / (r**2)
tmp_mean = mean - d * jnp.sqrt(custom_var) / r

# Apply the correction only after the warm-up period; before that, plain batch statistics are used.
is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
custom_var = is_warmed_up * tmp_var + (1.0 - is_warmed_up) * custom_var
custom_mean = is_warmed_up * tmp_mean + (1.0 - is_warmed_up) * custom_mean

ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
steps.value += 1

return _normalize(
self,
x,
custom_mean,
custom_var,
reduction_axes,
feature_axes,
self.dtype,
self.param_dtype,
self.epsilon,
self.use_bias,
self.use_scale,
self.bias_init,
self.scale_init,
)
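A minimal, self-contained sketch of how the new `BatchRenorm` layer above could be exercised on its own, following the usage note in its docstring (the import path assumes the module location added in this PR):

```python
import jax
import jax.numpy as jnp

from sbx.crossq.batch_renorm import BatchRenorm

x = jnp.ones((32, 64))  # (batch, features)
brn = BatchRenorm(momentum=0.99, epsilon=0.001)

# 'batch_stats' holds the running mean/var plus r_max, d_max and the step counter.
variables = brn.init(jax.random.PRNGKey(0), x, use_running_average=False)

# Training step: running statistics are updated, so mark 'batch_stats' as mutable.
y, mutated = brn.apply(variables, x, use_running_average=False, mutable=["batch_stats"])

# Evaluation: reuse the stored statistics without mutating them.
y_eval = brn.apply(
    {"params": variables["params"], "batch_stats": mutated["batch_stats"]},
    x,
    use_running_average=True,
)
```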
36 changes: 11 additions & 25 deletions sbx/crossq/crossq.py
@@ -53,16 +53,15 @@ def __init__(
self,
policy,
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
learning_rate: Union[float, Schedule] = 1e-3,
qf_learning_rate: Optional[float] = None,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
learning_starts: int = 5_000,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
policy_delay: int = 1,
policy_delay: int = 3,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -85,7 +84,6 @@ def __init__(
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
@@ -106,8 +104,11 @@ def __init__(
self.policy_delay = policy_delay
self.ent_coef_init = ent_coef

# if "optimizer_kwargs" not in self.policy_kwargs:
# self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}
if "net_arch" not in self.policy_kwargs:
self.policy_kwargs["net_arch"] = {"pi": [256, 256], "qf": [2048, 2048]}

if "optimizer_kwargs" not in self.policy_kwargs:
self.policy_kwargs["optimizer_kwargs"] = {"b1": 0.5}

if _init_setup_model:
self._setup_model()
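Taken together, the new defaults in this hunk spell out the CrossQ configuration: a higher learning rate, delayed policy updates, wide critic networks, and Adam with beta1 = 0.5. A hedged sketch of an equivalent explicit construction, mirroring only the values visible in the diff (illustrative, not the final merged API):

```python
from sbx import CrossQ

# Illustrative only: these values mirror the defaults set in __init__ above.
model = CrossQ(
    "MlpPolicy",
    "Pendulum-v1",
    learning_rate=1e-3,
    learning_starts=5_000,
    policy_delay=3,
    policy_kwargs={
        "net_arch": {"pi": [256, 256], "qf": [2048, 2048]},
        "optimizer_kwargs": {"b1": 0.5},  # Adam beta1
    },
)
```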
@@ -256,20 +257,12 @@ def update_critic(

ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})

# qf_next_values = qf_state.apply_fn(
# {"params": qf_state.params, "batch_stats": qf_state.batch_stats},
# next_observations,
# next_state_actions,
# rngs={"dropout": dropout_key_target},
# train=False, # todo: concatenate with obs, use train=True in that case
# )

# TODO: concatenate obs/next obs
def mse_loss(
params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
) -> Tuple[jax.Array, jax.Array]:
# Concatenate obs/next_obs to have only one forward pass
# shape is (n_critics, 2 * batch_size, 1)
# This directly computes the batch statistics for the mixture of (observations, actions)
# and (next_observations, next_state_actions).
q_values, state_updates = qf_state.apply_fn(
{"params": params, "batch_stats": batch_stats},
jnp.concatenate([observations, next_observations], axis=0),
@@ -347,12 +340,6 @@ def actor_loss(

return actor_state, qf_state, actor_loss_value, key, entropy

# @staticmethod
# @jax.jit
# def soft_update(tau: float, qf_state: BatchNormTrainState):
# qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
# return qf_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
@@ -447,7 +434,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
key,
)
# No target q values with CrossQ
# qf_state = cls.soft_update(tau, qf_state)

(actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
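The comments in `update_critic` describe the core CrossQ trick: there is no target network; instead a single batch-normalized forward pass is run over the concatenation of current and next observations, so the BatchRenorm statistics are computed over the joint batch. Below is a minimal sketch of the subsequent shape handling, using hypothetical names and the (n_critics, 2 * batch_size, 1) shape stated in the comment:

```python
from typing import Tuple

import jax.numpy as jnp


def split_joint_q_values(joint_q_values: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Split the output of the single joint forward pass.

    The first half along the batch axis corresponds to (observations, actions),
    the second half to (next_observations, next_state_actions); both halves were
    normalized with the same batch statistics, which replaces the target network.
    """
    current_q_values, next_q_values = jnp.split(joint_q_values, 2, axis=1)
    return current_q_values, next_q_values


# Example with dummy data: 2 critics, batch of 256 transitions.
joint = jnp.zeros((2, 2 * 256, 1))
current_q, next_q = split_joint_q_values(joint)
assert current_q.shape == next_q.shape == (2, 256, 1)
```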