Added support for large values for gradient_steps to SAC, TD3, and TQC (#21)

* Added support for large values for gradient_steps to SAC, TD3, and TQC by replacing the unrolled loop with jax.lax.fori_loop (a minimal sketch of the pattern follows the changed-files summary below)

* Add comments

* Hotfix for train signature

* Fixed start index for dynamic_slice_in_dim

* Rename policy delay

* Fix type annotation

* Remove old annotations

* Fix off-by-one and improve type annotation

* Fix typo

* [ci skip] Update README

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
3 people authored Feb 9, 2024
1 parent 0f9163d commit e564074
Showing 8 changed files with 315 additions and 165 deletions.
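
The commit's core change, shown in the sbx/sac/sac.py diff below and, per the commit message, mirrored in TD3 and TQC, replaces a Python-level for loop that jax.jit unrolls once per gradient step with jax.lax.fori_loop, so compile time no longer grows with gradient_steps. Here is a minimal, self-contained sketch of that pattern; the array shapes, the running-mean "update", and the carry fields are illustrative placeholders rather than SBX's actual train states:

```python
import jax
import jax.numpy as jnp

# Illustrative shapes: gradient_steps mini-batches pre-sampled and concatenated
# along axis 0, mirroring replay_buffer.sample(batch_size * gradient_steps).
gradient_steps, batch_size = 1000, 32
data = jnp.ones((gradient_steps * batch_size, 3))

def one_update(i, carry):
    # fori_loop expects fn(index, carry) -> carry with a fixed carry structure.
    batch = jax.lax.dynamic_slice_in_dim(data, i * batch_size, batch_size)
    # Placeholder "update": accumulate the batch mean; this stands in for the
    # critic/actor/temperature updates that the real loop carries through.
    return {"params": carry["params"] + batch.mean(),
            "n_updates": carry["n_updates"] + 1}

@jax.jit
def train(params):
    carry = {"params": params, "n_updates": jnp.array(0)}
    # The loop body is traced and compiled once, not unrolled gradient_steps times.
    return jax.lax.fori_loop(0, gradient_steps, one_update, carry)

print(train(jnp.array(0.0))["n_updates"])  # 1000
```

In the actual diff, the carry bundles actor_state, qf_state, ent_coef_state, the PRNG key, and an info dict of losses, and one_update slices each mini-batch out of the pre-sampled data with jax.lax.dynamic_slice_in_dim exactly as above.
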
6 changes: 5 additions & 1 deletion README.md
@@ -36,7 +36,7 @@ pip install sbx-rl
```python
import gymnasium as gym

from sbx import TQC, DroQ, SAC, PPO, DQN, TD3, DDPG
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ

env = gym.make("Pendulum-v1")

@@ -156,3 +156,7 @@ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv

To any interested in making the baselines better, there is still some documentation that needs to be done.
If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.

## Contributors

We would like to thank our contributors: [@jan1854](https://github.com/jan1854).
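
Aside for context (not part of the README diff): the feature this commit enables is running many gradient updates per rollout without an unrolled-loop compile cost. A hedged usage sketch; the environment and the train_freq, gradient_steps, and total_timesteps values are illustrative, not taken from the README:

```python
import gymnasium as gym

from sbx import SAC

env = gym.make("Pendulum-v1")
# Large gradient_steps values no longer inflate JIT compilation time, since the
# update loop is a single jax.lax.fori_loop rather than an unrolled Python loop.
model = SAC("MlpPolicy", env, train_freq=100, gradient_steps=100, verbose=1)
model.learn(total_timesteps=5_000)
```
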
2 changes: 0 additions & 2 deletions sbx/dqn/dqn.py
@@ -100,14 +100,12 @@ def _setup_model(self) -> None:
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

if not hasattr(self, "policy") or self.policy is None:
# pytype:disable=not-instantiable
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs,
)
# pytype:enable=not-instantiable

self.key = self.policy.build(self.key, self.lr_schedule)
self.qf = self.policy.qf
2 changes: 0 additions & 2 deletions sbx/ppo/ppo.py
@@ -168,14 +168,12 @@ def _setup_model(self) -> None:
super()._setup_model()

if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type]
# pytype:disable=not-instantiable
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs,
)
# pytype:enable=not-instantiable

self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm)

159 changes: 104 additions & 55 deletions sbx/sac/sac.py
@@ -9,6 +9,7 @@
import optax
from flax.training.train_state import TrainState
from gymnasium import spaces
from jax.typing import ArrayLike
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -112,14 +113,12 @@ def _setup_model(self) -> None:
super()._setup_model()

if not hasattr(self, "policy") or self.policy is None:
# pytype: disable=not-instantiable
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs,
)
# pytype: enable=not-instantiable

assert isinstance(self.qf_learning_rate, float)

@@ -183,11 +182,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Pre-compute the indices where we need to update the actor
# This is a hack in order to jit the train loop
# It will compile once per value of policy_delay_indices
policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment]

if isinstance(data.observations, dict):
keys = list(self.observation_space.keys()) # type: ignore[attr-defined]
@@ -218,7 +212,8 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.target_entropy,
gradient_steps,
data,
policy_delay_indices,
self.policy_delay,
(self._n_updates + 1) % self.policy_delay,
self.policy.qf_state,
self.policy.actor_state,
self.ent_coef_state,
@@ -237,11 +232,11 @@ def update_critic(
actor_state: TrainState,
qf_state: RLTrainState,
ent_coef_state: TrainState,
observations: np.ndarray,
actions: np.ndarray,
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
observations: jax.Array,
actions: jax.Array,
next_observations: jax.Array,
rewards: jax.Array,
dones: jax.Array,
key: jax.Array,
):
key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
@@ -265,7 +260,7 @@
# shape is (batch_size, 1)
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values

def mse_loss(params, dropout_key):
def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
# shape is (n_critics, batch_size, 1)
current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
@@ -285,12 +280,12 @@ def update_actor(
actor_state: RLTrainState,
qf_state: RLTrainState,
ent_coef_state: TrainState,
observations: np.ndarray,
observations: jax.Array,
key: jax.Array,
):
key, dropout_key, noise_key = jax.random.split(key, 3)

def actor_loss(params):
def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
dist = actor_state.apply_fn(params, observations)
actor_actions = dist.sample(seed=noise_key)
log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -314,16 +309,16 @@ def actor_loss(params):

@staticmethod
@jax.jit
def soft_update(tau: float, qf_state: RLTrainState):
def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
return qf_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: np.ndarray, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params):
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()
ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
@@ -332,62 +327,116 @@ def temperature_loss(temp_params):
return ent_coef_state, ent_coef_loss

@classmethod
@partial(jax.jit, static_argnames=["cls", "gradient_steps"])
def update_actor_and_temperature(
cls,
actor_state: RLTrainState,
qf_state: RLTrainState,
ent_coef_state: TrainState,
observations: jax.Array,
target_entropy: ArrayLike,
key: jax.Array,
):
(actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor(
actor_state,
qf_state,
ent_coef_state,
observations,
key,
)
ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy)
return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key

@classmethod
@partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"])
def _train(
cls,
gamma: float,
tau: float,
target_entropy: np.ndarray,
target_entropy: ArrayLike,
gradient_steps: int,
data: ReplayBufferSamplesNp,
policy_delay_indices: flax.core.FrozenDict,
policy_delay: int,
policy_delay_offset: int,
qf_state: RLTrainState,
actor_state: TrainState,
ent_coef_state: TrainState,
key,
key: jax.Array,
):
actor_loss_value = jnp.array(0)

for i in range(gradient_steps):

def slice(x, step=i):
assert x.shape[0] % gradient_steps == 0
batch_size = x.shape[0] // gradient_steps
return x[batch_size * step : batch_size * (step + 1)]

assert data.observations.shape[0] % gradient_steps == 0
batch_size = data.observations.shape[0] // gradient_steps

carry = {
"actor_state": actor_state,
"qf_state": qf_state,
"ent_coef_state": ent_coef_state,
"key": key,
"info": {
"actor_loss": jnp.array(0.0),
"qf_loss": jnp.array(0.0),
"ent_coef_loss": jnp.array(0.0),
},
}

def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
# Note: this method must be defined inline because
# `fori_loop` expects a signature fn(index, carry) -> carry
actor_state = carry["actor_state"]
qf_state = carry["qf_state"]
ent_coef_state = carry["ent_coef_state"]
key = carry["key"]
info = carry["info"]
batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size)
batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size)
batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size)
batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size)
batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size)
(
qf_state,
(qf_loss_value, ent_coef_value),
key,
) = SAC.update_critic(
) = cls.update_critic(
gamma,
actor_state,
qf_state,
ent_coef_state,
slice(data.observations),
slice(data.actions),
slice(data.next_observations),
slice(data.rewards),
slice(data.dones),
batch_obs,
batch_act,
batch_next_obs,
batch_rew,
batch_done,
key,
)
qf_state = cls.soft_update(tau, qf_state)

(actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
# If True:
cls.update_actor_and_temperature,
# If False:
lambda *_: (actor_state, qf_state, ent_coef_state, info["actor_loss"], info["ent_coef_loss"], key),
actor_state,
qf_state,
ent_coef_state,
batch_obs,
target_entropy,
key,
)
qf_state = SAC.soft_update(tau, qf_state)

# hack to be able to jit (n_updates % policy_delay == 0)
if i in policy_delay_indices:
(actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor(
actor_state,
qf_state,
ent_coef_state,
slice(data.observations),
key,
)
ent_coef_state, _ = SAC.update_temperature(target_entropy, ent_coef_state, entropy)
info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}

return {
"actor_state": actor_state,
"qf_state": qf_state,
"ent_coef_state": ent_coef_state,
"key": key,
"info": info,
}

update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry)

return (
qf_state,
actor_state,
ent_coef_state,
update_carry["qf_state"],
update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
(actor_loss_value, qf_loss_value, ent_coef_value),
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
)
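
The delayed actor and temperature update in the new _train above is handled by jax.lax.cond inside the fori_loop body, replacing the removed policy_delay_indices FrozenDict hack that triggered one compilation per set of indices. Below is a stripped-down sketch of that branching pattern; the scalar actor_param, the dummy update_actor, and the constants are illustrative stand-ins for the real train states:

```python
import jax
import jax.numpy as jnp

policy_delay = 2         # run the delayed update every policy_delay steps
policy_delay_offset = 0  # aligns the schedule with the global update counter

def update_actor(actor_param, key):
    # Placeholder for the real actor/temperature update.
    key, subkey = jax.random.split(key)
    return actor_param + 0.01 * jax.random.normal(subkey), key

@jax.jit
def one_update(i, actor_param, key):
    # Both branches must return pytrees with the same structure and dtypes.
    return jax.lax.cond(
        (policy_delay_offset + i) % policy_delay == 0,
        update_actor,            # if True: run the delayed update
        lambda p, k: (p, k),     # if False: pass the state through unchanged
        actor_param,
        key,
    )

actor_param, key = jnp.array(0.0), jax.random.PRNGKey(0)
for i in range(4):
    actor_param, key = one_update(i, actor_param, key)
```

Keeping both branches structurally identical is what lets the condition be evaluated inside the jitted loop; the real code does the same by having the false branch (lambda *_: ...) return the untouched states and the previously logged losses.
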