diff --git a/README.md b/README.md index 0235e82..14eaa39 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ pip install sbx-rl ```python import gymnasium as gym -from sbx import TQC, DroQ, SAC, PPO, DQN, TD3, DDPG +from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ env = gym.make("Pendulum-v1") @@ -156,3 +156,7 @@ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv To any interested in making the baselines better, there is still some documentation that needs to be done. If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first. + +## Contributors + +We would like to thank our contributors: [@jan1854](https://github.com/jan1854). diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index b3fc3b4..cde94b8 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -100,14 +100,12 @@ def _setup_model(self) -> None: self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) if not hasattr(self, "policy") or self.policy is None: - # pytype:disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, **self.policy_kwargs, ) - # pytype:enable=not-instantiable self.key = self.policy.build(self.key, self.lr_schedule) self.qf = self.policy.qf diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index ae2ed98..313f73c 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -168,14 +168,12 @@ def _setup_model(self) -> None: super()._setup_model() if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] - # pytype:disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, **self.policy_kwargs, ) - # pytype:enable=not-instantiable self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm) diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index f4e01ef..aab5e59 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -9,6 +9,7 @@ import optax from flax.training.train_state import TrainState from gymnasium import spaces +from jax.typing import ArrayLike from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -112,14 +113,12 @@ def _setup_model(self) -> None: super()._setup_model() if not hasattr(self, "policy") or self.policy is None: - # pytype: disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, **self.policy_kwargs, ) - # pytype: enable=not-instantiable assert isinstance(self.qf_learning_rate, float) @@ -183,11 +182,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) - # Pre-compute the indices where we need to update the actor - # This is a hack in order to jit the train loop - # It will compile once per value of policy_delay_indices - policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0} - policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment] if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] @@ -218,7 +212,8 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.target_entropy, gradient_steps, data, - policy_delay_indices, + self.policy_delay, + (self._n_updates + 1) % self.policy_delay, self.policy.qf_state, self.policy.actor_state, self.ent_coef_state, @@ -237,11 +232,11 @@ def update_critic( actor_state: TrainState, qf_state: RLTrainState, ent_coef_state: TrainState, - observations: np.ndarray, - actions: np.ndarray, - next_observations: np.ndarray, - rewards: np.ndarray, - dones: np.ndarray, + observations: jax.Array, + actions: jax.Array, + next_observations: jax.Array, + rewards: jax.Array, + dones: jax.Array, key: jax.Array, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -265,7 +260,7 @@ def update_critic( # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values - def mse_loss(params, dropout_key): + def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # shape is (n_critics, batch_size, 1) current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key}) return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum() @@ -285,12 +280,12 @@ def update_actor( actor_state: RLTrainState, qf_state: RLTrainState, ent_coef_state: TrainState, - observations: np.ndarray, + observations: jax.Array, key: jax.Array, ): key, dropout_key, noise_key = jax.random.split(key, 3) - def actor_loss(params): + def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]: dist = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) @@ -314,16 +309,16 @@ def actor_loss(params): @staticmethod @jax.jit - def soft_update(tau: float, qf_state: RLTrainState): + def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState: qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau)) return qf_state @staticmethod @jax.jit - def update_temperature(target_entropy: np.ndarray, ent_coef_state: TrainState, entropy: float): - def temperature_loss(temp_params): + def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): + def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) - ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() + ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) @@ -332,62 +327,116 @@ def temperature_loss(temp_params): return ent_coef_state, ent_coef_loss @classmethod - @partial(jax.jit, static_argnames=["cls", "gradient_steps"]) + def update_actor_and_temperature( + cls, + actor_state: RLTrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + observations: jax.Array, + target_entropy: ArrayLike, + key: jax.Array, + ): + (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( + actor_state, + qf_state, + ent_coef_state, + observations, + key, + ) + ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) + return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + @classmethod + @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"]) def _train( cls, gamma: float, tau: float, - target_entropy: np.ndarray, + target_entropy: ArrayLike, gradient_steps: int, data: ReplayBufferSamplesNp, - policy_delay_indices: flax.core.FrozenDict, + policy_delay: int, + policy_delay_offset: int, qf_state: RLTrainState, actor_state: TrainState, ent_coef_state: TrainState, - key, + key: jax.Array, ): - actor_loss_value = jnp.array(0) - - for i in range(gradient_steps): - - def slice(x, step=i): - assert x.shape[0] % gradient_steps == 0 - batch_size = x.shape[0] // gradient_steps - return x[batch_size * step : batch_size * (step + 1)] - + assert data.observations.shape[0] % gradient_steps == 0 + batch_size = data.observations.shape[0] // gradient_steps + + carry = { + "actor_state": actor_state, + "qf_state": qf_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": { + "actor_loss": jnp.array(0.0), + "qf_loss": jnp.array(0.0), + "ent_coef_loss": jnp.array(0.0), + }, + } + + def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + # Note: this method must be defined inline because + # `fori_loop` expect a signature fn(index, carry) -> carry + actor_state = carry["actor_state"] + qf_state = carry["qf_state"] + ent_coef_state = carry["ent_coef_state"] + key = carry["key"] + info = carry["info"] + batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) + batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) + batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) ( qf_state, (qf_loss_value, ent_coef_value), key, - ) = SAC.update_critic( + ) = cls.update_critic( gamma, actor_state, qf_state, ent_coef_state, - slice(data.observations), - slice(data.actions), - slice(data.next_observations), - slice(data.rewards), - slice(data.dones), + batch_obs, + batch_act, + batch_next_obs, + batch_rew, + batch_done, + key, + ) + qf_state = cls.soft_update(tau, qf_state) + + (actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond( + (policy_delay_offset + i) % policy_delay == 0, + # If True: + cls.update_actor_and_temperature, + # If False: + lambda *_: (actor_state, qf_state, ent_coef_state, info["actor_loss"], info["ent_coef_loss"], key), + actor_state, + qf_state, + ent_coef_state, + batch_obs, + target_entropy, key, ) - qf_state = SAC.soft_update(tau, qf_state) - - # hack to be able to jit (n_updates % policy_delay == 0) - if i in policy_delay_indices: - (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( - actor_state, - qf_state, - ent_coef_state, - slice(data.observations), - key, - ) - ent_coef_state, _ = SAC.update_temperature(target_entropy, ent_coef_state, entropy) + info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value} + + return { + "actor_state": actor_state, + "qf_state": qf_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": info, + } + + update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry) return ( - qf_state, - actor_state, - ent_coef_state, + update_carry["qf_state"], + update_carry["actor_state"], + update_carry["ent_coef_state"], key, - (actor_loss_value, qf_loss_value, ent_coef_value), + (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]), ) diff --git a/sbx/td3/td3.py b/sbx/td3/td3.py index dd38877..617d4f3 100644 --- a/sbx/td3/td3.py +++ b/sbx/td3/td3.py @@ -87,14 +87,12 @@ def _setup_model(self) -> None: super()._setup_model() if not hasattr(self, "policy") or self.policy is None: - # pytype: disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, **self.policy_kwargs, ) - # pytype: enable=not-instantiable assert isinstance(self.qf_learning_rate, float) @@ -125,11 +123,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) - # Pre-compute the indices where we need to update the actor - # This is a hack in order to jit the train loop - # It will compile once per value of policy_delay_indices - policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0} - policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment] if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] @@ -158,7 +151,8 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.tau, gradient_steps, data, - policy_delay_indices, + self.policy_delay, + (self._n_updates + 1) % self.policy_delay, self.target_policy_noise, self.target_noise_clip, self.policy.qf_state, @@ -176,11 +170,11 @@ def update_critic( gamma: float, actor_state: RLTrainState, qf_state: RLTrainState, - observations: np.ndarray, - actions: np.ndarray, - next_observations: np.ndarray, - rewards: np.ndarray, - dones: np.ndarray, + observations: jax.Array, + actions: jax.Array, + next_observations: jax.Array, + rewards: jax.Array, + dones: jax.Array, target_policy_noise: float, target_noise_clip: float, key: jax.Array, @@ -204,7 +198,7 @@ def update_critic( # shape is (batch_size, 1) target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values - def mse_loss(params, dropout_key): + def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # shape is (n_critics, batch_size, 1) current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key}) return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum() @@ -223,12 +217,12 @@ def mse_loss(params, dropout_key): def update_actor( actor_state: RLTrainState, qf_state: RLTrainState, - observations: np.ndarray, + observations: jax.Array, key: jax.Array, ): key, dropout_key = jax.random.split(key, 2) - def actor_loss(params): + def actor_loss(params: flax.core.FrozenDict) -> jax.Array: actor_actions = actor_state.apply_fn(params, observations) qf_pi = qf_state.apply_fn( @@ -249,7 +243,7 @@ def actor_loss(params): @staticmethod @jax.jit - def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState): + def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]: qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau)) actor_state = actor_state.replace( target_params=optax.incremental_update(actor_state.params, actor_state.target_params, tau) @@ -257,60 +251,90 @@ def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState): return qf_state, actor_state @classmethod - @partial(jax.jit, static_argnames=["cls", "gradient_steps"]) + @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"]) def _train( cls, gamma: float, tau: float, gradient_steps: int, data: ReplayBufferSamplesNp, - policy_delay_indices: flax.core.FrozenDict, + policy_delay: int, + policy_delay_offset: int, target_policy_noise: float, target_noise_clip: float, qf_state: RLTrainState, actor_state: RLTrainState, - key, + key: jax.Array, ): - actor_loss_value = jnp.array(0) - - for i in range(gradient_steps): - - def slice(x, step=i): - assert x.shape[0] % gradient_steps == 0 - batch_size = x.shape[0] // gradient_steps - return x[batch_size * step : batch_size * (step + 1)] - + assert data.observations.shape[0] % gradient_steps == 0 + batch_size = data.observations.shape[0] // gradient_steps + + carry = { + "actor_state": actor_state, + "qf_state": qf_state, + "key": key, + "info": { + "actor_loss": jnp.array(0.0), + "qf_loss": jnp.array(0.0), + }, + } + + def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + # Note: this method must be defined inline because + # `fori_loop` expect a signature fn(index, carry) -> carry + actor_state = carry["actor_state"] + qf_state = carry["qf_state"] + key = carry["key"] + info = carry["info"] + batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) + batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) + batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) ( qf_state, qf_loss_value, key, - ) = TD3.update_critic( + ) = cls.update_critic( gamma, actor_state, qf_state, - slice(data.observations), - slice(data.actions), - slice(data.next_observations), - slice(data.rewards), - slice(data.dones), + batch_obs, + batch_act, + batch_next_obs, + batch_rew, + batch_done, target_policy_noise, target_noise_clip, key, ) - qf_state, actor_state = TD3.soft_update(tau, qf_state, actor_state) + qf_state, actor_state = cls.soft_update(tau, qf_state, actor_state) + + (actor_state, qf_state, actor_loss_value, key) = jax.lax.cond( + (policy_delay_offset + i) % policy_delay == 0, + # If True: + cls.update_actor, + # If False: + lambda *_: (actor_state, qf_state, info["actor_loss"], key), + actor_state, + qf_state, + batch_obs, + key, + ) + info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value} + + return { + "actor_state": actor_state, + "qf_state": qf_state, + "key": key, + "info": info, + } - # hack to be able to jit (n_updates % policy_delay == 0) - if i in policy_delay_indices: - (actor_state, qf_state, actor_loss_value, key) = cls.update_actor( - actor_state, - qf_state, - slice(data.observations), - key, - ) + update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry) return ( - qf_state, - actor_state, + update_carry["qf_state"], + update_carry["actor_state"], key, - (actor_loss_value, qf_loss_value), + (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"]), ) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 8db0773..4c7b9e6 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -9,6 +9,7 @@ import optax from flax.training.train_state import TrainState from gymnasium import spaces +from jax.typing import ArrayLike from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -114,14 +115,13 @@ def _setup_model(self) -> None: super()._setup_model() if not hasattr(self, "policy") or self.policy is None: - # pytype: disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, **self.policy_kwargs, ) - # pytype: enable=not-instantiable + assert isinstance(self.qf_learning_rate, float) self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) @@ -184,11 +184,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) - # Pre-compute the indices where we need to update the actor - # This is a hack in order to jit the train loop - # It will compile once per value of policy_delay_indices - policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0} - policy_delay_indices = flax.core.FrozenDict(policy_delay_indices) # type: ignore[assignment] if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] @@ -220,7 +215,8 @@ def train(self, gradient_steps: int, batch_size: int) -> None: gradient_steps, self.policy.n_target_quantiles, data, - policy_delay_indices, + self.policy_delay, + (self._n_updates + 1) % self.policy_delay, self.policy.qf1_state, self.policy.qf2_state, self.policy.actor_state, @@ -242,11 +238,11 @@ def update_critic( qf1_state: RLTrainState, qf2_state: RLTrainState, ent_coef_state: TrainState, - observations: np.ndarray, - actions: np.ndarray, - next_observations: np.ndarray, - rewards: np.ndarray, - dones: np.ndarray, + observations: jax.Array, + actions: jax.Array, + next_observations: jax.Array, + rewards: jax.Array, + dones: jax.Array, key: jax.Array, ): key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) @@ -289,9 +285,9 @@ def update_critic( # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles). target_quantiles = jnp.expand_dims(target_quantiles, axis=1) - def huber_quantile_loss(params, noise_key): + def huber_quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # Compute huber quantile loss - current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": noise_key}) + current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": dropout_key}) # convert to shape: (batch_size, n_quantiles, 1) for broadcast current_quantiles = jnp.expand_dims(current_quantiles, axis=-1) @@ -327,12 +323,12 @@ def update_actor( qf1_state: RLTrainState, qf2_state: RLTrainState, ent_coef_state: TrainState, - observations: np.ndarray, + observations: jax.Array, key: jax.Array, ): key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4) - def actor_loss(params): + def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]: dist = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) @@ -369,18 +365,18 @@ def actor_loss(params): @staticmethod @jax.jit - def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState): + def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]: qf1_state = qf1_state.replace(target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, tau)) qf2_state = qf2_state.replace(target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, tau)) return qf1_state, qf2_state @staticmethod @jax.jit - def update_temperature(target_entropy: np.ndarray, ent_coef_state: TrainState, entropy: float): - def temperature_loss(temp_params): + def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): + def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) # ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean() - ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() + ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) @@ -388,68 +384,149 @@ def temperature_loss(temp_params): return ent_coef_state, ent_coef_loss - @staticmethod - @partial(jax.jit, static_argnames=["gradient_steps", "n_target_quantiles"]) + @classmethod + def update_actor_and_temperature( + cls, + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + ent_coef_state: TrainState, + observations: jax.Array, + target_entropy: ArrayLike, + key: jax.Array, + ): + (actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy) = cls.update_actor( + actor_state, + qf1_state, + qf2_state, + ent_coef_state, + observations, + key, + ) + ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) + return actor_state, (qf1_state, qf2_state), ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + @classmethod + @partial( + jax.jit, + static_argnames=["cls", "gradient_steps", "n_target_quantiles", "policy_delay", "policy_delay_offset"], + ) def _train( + cls, gamma: float, tau: float, - target_entropy: np.ndarray, + target_entropy: ArrayLike, gradient_steps: int, n_target_quantiles: int, data: ReplayBufferSamplesNp, - policy_delay_indices: flax.core.FrozenDict, + policy_delay: int, + policy_delay_offset: int, qf1_state: RLTrainState, qf2_state: RLTrainState, actor_state: TrainState, ent_coef_state: TrainState, - key, + key: jax.Array, ): - actor_loss_value = jnp.array(0) - - for i in range(gradient_steps): - - def slice(x, step=i): - assert x.shape[0] % gradient_steps == 0 - batch_size = x.shape[0] // gradient_steps - return x[batch_size * step : batch_size * (step + 1)] - + assert data.observations.shape[0] % gradient_steps == 0 + batch_size = data.observations.shape[0] // gradient_steps + + carry = { + "actor_state": actor_state, + "qf1_state": qf1_state, + "qf2_state": qf2_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": { + "actor_loss": jnp.array(0.0), + "qf1_loss": jnp.array(0.0), + "qf2_loss": jnp.array(0.0), + "ent_coef_loss": jnp.array(0.0), + }, + } + + def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + # Note: this method must be defined inline because + # `fori_loop` expect a signature fn(index, carry) -> carry + actor_state = carry["actor_state"] + qf1_state = carry["qf1_state"] + qf2_state = carry["qf2_state"] + ent_coef_state = carry["ent_coef_state"] + key = carry["key"] + info = carry["info"] + batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) + batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) + batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) ( (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value, ent_coef_value), key, - ) = TQC.update_critic( + ) = cls.update_critic( gamma, n_target_quantiles, actor_state, qf1_state, qf2_state, ent_coef_state, - slice(data.observations), - slice(data.actions), - slice(data.next_observations), - slice(data.rewards), - slice(data.dones), + batch_obs, + batch_act, + batch_next_obs, + batch_rew, + batch_done, key, ) - qf1_state, qf2_state = TQC.soft_update(tau, qf1_state, qf2_state) - - # hack to be able to jit (n_updates % policy_delay == 0) - if i in policy_delay_indices: - (actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy) = TQC.update_actor( + qf1_state, qf2_state = cls.soft_update(tau, qf1_state, qf2_state) + + (actor_state, (qf1_state, qf2_state), ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond( + (policy_delay_offset + i) % policy_delay == 0, + # If True: + cls.update_actor_and_temperature, + # If False: + lambda *_: ( actor_state, - qf1_state, - qf2_state, + (qf1_state, qf2_state), ent_coef_state, - slice(data.observations), + info["actor_loss"], + info["ent_coef_loss"], key, - ) - ent_coef_state, _ = TQC.update_temperature(target_entropy, ent_coef_state, entropy) + ), + actor_state, + qf1_state, + qf2_state, + ent_coef_state, + batch_obs, + target_entropy, + key, + ) + info = { + "actor_loss": actor_loss_value, + "qf1_loss": qf1_loss_value, + "qf2_loss": qf2_loss_value, + "ent_coef_loss": ent_coef_loss_value, + } + + return { + "actor_state": actor_state, + "qf1_state": qf1_state, + "qf2_state": qf2_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": info, + } + + update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry) return ( - qf1_state, - qf2_state, - actor_state, - ent_coef_state, + update_carry["qf1_state"], + update_carry["qf2_state"], + update_carry["actor_state"], + update_carry["ent_coef_state"], key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value), + ( + update_carry["info"]["qf1_loss"], + update_carry["info"]["qf2_loss"], + update_carry["info"]["actor_loss"], + update_carry["info"]["ent_coef_loss"], + ), ) diff --git a/sbx/version.txt b/sbx/version.txt index 78bc1ab..d9df1bb 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.10.0 +0.11.0 diff --git a/setup.py b/setup.py index ad2f657..1c519cb 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ ## Example ```python -from sbx import TQC, DroQ, SAC, DQN, PPO, TD3, DDPG +from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ model = TQC("MlpPolicy", "Pendulum-v1", verbose=1) model.learn(total_timesteps=10_000, progress_bar=True)