
Commit

add SAC backbone for DrQ and RLPD
Leo428 committed Nov 2, 2023
1 parent def8cc8 commit f02729d
Showing 4 changed files with 278 additions and 9 deletions.
Empty file added serl/agents/sac/__init__.py
Empty file.
238 changes: 238 additions & 0 deletions serl/agents/sac/sac_learner.py
@@ -0,0 +1,238 @@
"""Implementations of algorithms for continuous control."""
from functools import partial
from typing import Dict, Optional, Sequence, Tuple

import gym
import jax
import jax.numpy as jnp
import optax
from flax import struct
from flax.training.train_state import TrainState

from serl.agents.agent import Agent
from serl.agents.sac.temperature import Temperature
from serl.data.dataset import DatasetDict
from serl.distributions import TanhNormal
from serl.networks import MLP, Ensemble, StateActionValue, subsample_ensemble


class SACLearner(Agent):
critic: TrainState
target_critic: TrainState
temp: TrainState
tau: float
discount: float
target_entropy: float
num_qs: int = struct.field(pytree_node=False)
num_min_qs: Optional[int] = struct.field(
pytree_node=False
) # See M in RedQ https://arxiv.org/abs/2101.05982
backup_entropy: bool = struct.field(pytree_node=False)

@classmethod
def create(
cls,
seed: int,
observation_space: gym.Space,
action_space: gym.Space,
actor_lr: float = 3e-4,
critic_lr: float = 3e-4,
temp_lr: float = 3e-4,
hidden_dims: Sequence[int] = (256, 256),
discount: float = 0.99,
tau: float = 0.005,
num_qs: int = 2,
num_min_qs: Optional[int] = None,
critic_dropout_rate: Optional[float] = None,
critic_layer_norm: bool = False,
target_entropy: Optional[float] = None,
init_temperature: float = 1.0,
backup_entropy: bool = True,
):
"""
An implementation of the version of Soft-Actor-Critic described in https://arxiv.org/abs/1812.05905
"""

action_dim = action_space.shape[-1]
observations = observation_space.sample()
actions = action_space.sample()

if target_entropy is None:
target_entropy = -action_dim / 2
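            # default: half of the -dim(A) heuristic used in the SAC paper cited above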

rng = jax.random.PRNGKey(seed)
rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4)

actor_base_cls = partial(MLP, hidden_dims=hidden_dims, activate_final=True)
actor_def = TanhNormal(actor_base_cls, action_dim)
actor_params = actor_def.init(actor_key, observations)["params"]
actor = TrainState.create(
apply_fn=actor_def.apply,
params=actor_params,
tx=optax.adam(learning_rate=actor_lr),
)

critic_base_cls = partial(
MLP,
hidden_dims=hidden_dims,
activate_final=True,
dropout_rate=critic_dropout_rate,
use_layer_norm=critic_layer_norm,
)
critic_cls = partial(StateActionValue, base_cls=critic_base_cls)
critic_def = Ensemble(critic_cls, num=num_qs)
critic_params = critic_def.init(critic_key, observations, actions)["params"]
critic = TrainState.create(
apply_fn=critic_def.apply,
params=critic_params,
tx=optax.adam(learning_rate=critic_lr),
)

target_critic_def = Ensemble(critic_cls, num=num_min_qs or num_qs)
target_critic = TrainState.create(
apply_fn=target_critic_def.apply,
params=critic_params,
tx=optax.GradientTransformation(lambda _: None, lambda _: None),
)
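        # The target critic is never optimized directly; update_critic moves its
        # params toward the online critic via Polyak averaging, so its optimizer
        # is a no-op GradientTransformation.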

temp_def = Temperature(init_temperature)
temp_params = temp_def.init(temp_key)["params"]
temp = TrainState.create(
apply_fn=temp_def.apply,
params=temp_params,
tx=optax.adam(learning_rate=temp_lr),
)

return cls(
rng=rng,
actor=actor,
critic=critic,
target_critic=target_critic,
temp=temp,
target_entropy=target_entropy,
tau=tau,
discount=discount,
num_qs=num_qs,
num_min_qs=num_min_qs,
backup_entropy=backup_entropy,
)

def update_actor(self, batch: DatasetDict) -> Tuple[Agent, Dict[str, float]]:
key, rng = jax.random.split(self.rng)
key2, rng = jax.random.split(rng)

def actor_loss_fn(actor_params) -> Tuple[jnp.ndarray, Dict[str, float]]:
dist = self.actor.apply_fn({"params": actor_params}, batch["observations"])
actions = dist.sample(seed=key)
log_probs = dist.log_prob(actions)
qs = self.critic.apply_fn(
{"params": self.critic.params},
batch["observations"],
actions,
True,
rngs={"dropout": key2},
) # training=True
q = qs.mean(axis=0)
actor_loss = (
log_probs * self.temp.apply_fn({"params": self.temp.params}) - q
).mean()
return actor_loss, {"actor_loss": actor_loss, "entropy": -log_probs.mean()}

grads, actor_info = jax.grad(actor_loss_fn, has_aux=True)(self.actor.params)
actor = self.actor.apply_gradients(grads=grads)

return self.replace(actor=actor, rng=rng), actor_info

def update_temperature(self, entropy: float) -> Tuple[Agent, Dict[str, float]]:
def temperature_loss_fn(temp_params):
temperature = self.temp.apply_fn({"params": temp_params})
temp_loss = temperature * (entropy - self.target_entropy).mean()
return temp_loss, {
"temperature": temperature,
"temperature_loss": temp_loss,
}

grads, temp_info = jax.grad(temperature_loss_fn, has_aux=True)(self.temp.params)
temp = self.temp.apply_gradients(grads=grads)

return self.replace(temp=temp), temp_info

    def update_critic(self, batch: DatasetDict) -> Tuple[Agent, Dict[str, float]]:

dist = self.actor.apply_fn(
{"params": self.actor.params}, batch["next_observations"]
)

rng = self.rng

key, rng = jax.random.split(rng)
next_actions = dist.sample(seed=key)

# Used only for REDQ.
key, rng = jax.random.split(rng)
target_params = subsample_ensemble(
key, self.target_critic.params, self.num_min_qs, self.num_qs
)

key, rng = jax.random.split(rng)
next_qs = self.target_critic.apply_fn(
{"params": target_params},
batch["next_observations"],
next_actions,
True,
rngs={"dropout": key},
) # training=True
next_q = next_qs.min(axis=0)

target_q = batch["rewards"] + self.discount * batch["masks"] * next_q

if self.backup_entropy:
next_log_probs = dist.log_prob(next_actions)
target_q -= (
self.discount
* batch["masks"]
* self.temp.apply_fn({"params": self.temp.params})
* next_log_probs
)

key, rng = jax.random.split(rng)

def critic_loss_fn(critic_params) -> Tuple[jnp.ndarray, Dict[str, float]]:
qs = self.critic.apply_fn(
{"params": critic_params},
batch["observations"],
batch["actions"],
True,
rngs={"dropout": key},
) # training=True
critic_loss = ((qs - target_q) ** 2).mean()
return critic_loss, {"critic_loss": critic_loss, "q": qs.mean()}

grads, info = jax.grad(critic_loss_fn, has_aux=True)(self.critic.params)
critic = self.critic.apply_gradients(grads=grads)

target_critic_params = optax.incremental_update(
critic.params, self.target_critic.params, self.tau
)
target_critic = self.target_critic.replace(params=target_critic_params)

return self.replace(critic=critic, target_critic=target_critic, rng=rng), info

@partial(jax.jit, static_argnames="utd_ratio")
def update(self, batch: DatasetDict, utd_ratio: int):
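        # Split the sampled batch into `utd_ratio` mini-batches and take one critic
        # step per mini-batch; the actor and temperature are updated once afterwards
        # (high update-to-data ratio, as in REDQ / RLPD).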

new_agent = self
for i in range(utd_ratio):

def slice(x):
assert x.shape[0] % utd_ratio == 0
batch_size = x.shape[0] // utd_ratio
return x[batch_size * i : batch_size * (i + 1)]

mini_batch = jax.tree_util.tree_map(slice, batch)
new_agent, critic_info = new_agent.update_critic(mini_batch)

new_agent, actor_info = new_agent.update_actor(mini_batch)
new_agent, temp_info = new_agent.update_temperature(actor_info["entropy"])

return new_agent, {**actor_info, **critic_info, **temp_info}
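
A minimal usage sketch (not part of this commit) of how the new learner might be driven. The environment, batch shapes, and replay-buffer layout below are illustrative assumptions; only SACLearner.create and update come from the code above.

import gym
import numpy as np

from serl.agents.sac.sac_learner import SACLearner

# Hypothetical setup: any continuous-control gym environment with a Box action space.
env = gym.make("Pendulum-v1")
agent = SACLearner.create(
    seed=0,
    observation_space=env.observation_space,
    action_space=env.action_space,
    num_qs=10,      # critic ensemble size
    num_min_qs=2,   # REDQ-style subsampling of target Q heads
)

# A dummy batch with the keys update() expects; in practice it would come from a
# replay buffer. With utd_ratio=20, the leading dimension must be 20 * batch_size.
utd_ratio, batch_size = 20, 256
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
batch = {
    "observations": np.zeros((utd_ratio * batch_size, obs_dim), np.float32),
    "actions": np.zeros((utd_ratio * batch_size, act_dim), np.float32),
    "rewards": np.zeros((utd_ratio * batch_size,), np.float32),
    "masks": np.ones((utd_ratio * batch_size,), np.float32),  # 0.0 where the episode terminated
    "next_observations": np.zeros((utd_ratio * batch_size, obs_dim), np.float32),
}

agent, info = agent.update(batch, utd_ratio=utd_ratio)
print(info["critic_loss"], info["actor_loss"], info["temperature"])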
14 changes: 14 additions & 0 deletions serl/agents/sac/temperature.py
@@ -0,0 +1,14 @@
import flax.linen as nn
import jax.numpy as jnp


class Temperature(nn.Module):
initial_temperature: float = 1.0

@nn.compact
def __call__(self) -> jnp.ndarray:
log_temp = self.param(
"log_temp",
init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature)),
)
return jnp.exp(log_temp)
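
For reference, a short sketch of how this Temperature module behaves when initialized and applied (values shown assume initial_temperature=1.0):

import jax

from serl.agents.sac.temperature import Temperature

temp_def = Temperature(initial_temperature=1.0)
params = temp_def.init(jax.random.PRNGKey(0))["params"]  # {'log_temp': 0.0}
alpha = temp_def.apply({"params": params})               # exp(0.0) == 1.0
# Storing log_temp and exponentiating keeps the temperature strictly positive
# while the optimizer takes unconstrained gradient steps on log_temp.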
35 changes: 26 additions & 9 deletions serl/utils/commons.py
@@ -53,19 +53,36 @@ def get_data(data, start, end):
    :param end: the end index of the data range
:return: eventually returns the numpy array within range
'''

if type(data) == dict:
return {k: get_data(v, start, end) for k,v in data.items()}
return data[start:end]

def restore_checkpoint_(path, item, step):
    '''
    helper function to restore checkpoints from a path, checks if the path exists
    :param path: the path to the checkpoints folder
    :param item: the TrainState to restore
    :param step: the step to restore
    :return: the restored TrainState
    '''
    assert os.path.exists(path)
    return checkpoints.restore_checkpoint(path, item, step)

def _reset_weights(source, target):
    '''
    Reset weights of target to source
    TODO: change this to take params directly instead of TrainState
    :param source: the source network, TrainState
    :param target: the target network, TrainState
    '''
    replacers = {}
    for k, v in source.params.items():
        if "encoder" not in k:
            replacers[k] = v

    new_params = target.params.copy(add_or_replace=replacers)
    return target.replace(params=new_params)
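
A hedged sketch of how _reset_weights might be used; fresh and trained are hypothetical TrainStates with identical parameter trees (e.g. the actor of a freshly created pixel agent and of a partially trained one):

from serl.utils.commons import _reset_weights

# Copies every top-level parameter group whose key does not contain "encoder"
# from `fresh` into `trained`, keeping the trained encoder weights intact.
reset_actor = _reset_weights(source=fresh, target=trained)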
