Skip to content

Commit

Permalink
add drq config helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Leo428 committed Dec 21, 2023
1 parent 82e439e commit d8d74df
Showing 1 changed file with 115 additions and 2 deletions.
117 changes: 115 additions & 2 deletions serl_launcher/serl_launcher/utils/jaxrl_m_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@
from collections import deque
from functools import partial
from threading import Lock
from typing import Union, Iterable

import gymnasium as gym
import jax
from serl_launcher.data.serl_replay_buffer import ReplayBuffer
from serl_launcher.data.serl_memory_efficient_replay_buffer import (
MemoryEfficientReplayBuffer,
)

from edgeml.data.data_store import DataStoreBase
from edgeml.trainer import TrainerConfig

from jax import nn
from jaxrl_m.agents.continuous.sac import SACAgent
from jaxrl_m.common.wandb import WandBLogger

##############################################################################
from jaxrl_m.agents.continuous.drq import DrQAgent
from jaxrl_m.vision.small_encoders import SmallEncoder
from jaxrl_m.vision.mobilenet import MobileNetEncoder


class ReplayBufferDataStore(ReplayBuffer, DataStoreBase):
Expand Down Expand Up @@ -51,6 +56,41 @@ def get_latest_data(self, from_id: int):
raise NotImplementedError # TODO


class MemoryEfficientReplayBufferDataStore(MemoryEfficientReplayBuffer, DataStoreBase):
    """A ``MemoryEfficientReplayBuffer`` that is also a ``DataStoreBase``.

    ``insert`` and ``sample`` are forwarded to the parent implementation
    while holding a single ``threading.Lock``, so the two calls never run
    concurrently on the same instance.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        capacity: int,
        image_keys: Iterable[str] = ("image",),
    ):
        """Initialise both base classes and create the guarding lock.

        Args:
            observation_space: forwarded to ``MemoryEfficientReplayBuffer``.
            action_space: forwarded to ``MemoryEfficientReplayBuffer``.
            capacity: forwarded to both ``MemoryEfficientReplayBuffer`` and
                ``DataStoreBase``.
            image_keys: forwarded to the replay buffer as ``pixel_keys``.
        """
        # The two bases take different constructor arguments, so each one is
        # initialised explicitly rather than through a cooperative super().
        MemoryEfficientReplayBuffer.__init__(
            self,
            observation_space,
            action_space,
            capacity,
            pixel_keys=image_keys,
        )
        DataStoreBase.__init__(self, capacity)
        self._lock = Lock()

    def insert(self, *args, **kwargs):
        """Run the inherited ``insert`` under the lock; returns ``None``."""
        with self._lock:
            super().insert(*args, **kwargs)

    def sample(self, *args, **kwargs):
        """Run the inherited ``sample`` under the lock and return its result."""
        with self._lock:
            batch = super().sample(*args, **kwargs)
            return batch

    # NOTE: method for DataStoreBase
    def latest_data_id(self):
        """Return the buffer's current ``_insert_index``."""
        return self._insert_index

    # NOTE: method for DataStoreBase
    def get_latest_data(self, from_id: int):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError  # TODO


##############################################################################


Expand All @@ -62,6 +102,79 @@ def make_sac_agent(seed, sample_obs, sample_action):
policy_kwargs={
"tanh_squash_distribution": True,
"std_parameterization": "exp",
"std_min": 1e-5,
"std_max": 5,
},
critic_network_kwargs={
"activations": nn.tanh,
"use_layer_norm": True,
"hidden_dims": [256, 256],
},
policy_network_kwargs={
"activations": nn.tanh,
"use_layer_norm": True,
"hidden_dims": [256, 256],
},
temperature_init=1e-2,
discount=0.99,
backup_entropy=False,
critic_ensemble_size=10,
critic_subsample_size=2,
)


def make_drq_agent(
seed, sample_obs, sample_action, image_keys=("image",), encoder_type="small"
):
if encoder_type == "small":
encoder_defs = {
image_key: SmallEncoder(
features=(32, 64, 128, 256),
kernel_sizes=(3, 3, 3, 3),
strides=(2, 2, 2, 2),
padding="VALID",
pool_method="avg",
bottleneck_dim=256,
spatial_block_size=8,
name=f"encoder_{image_key}",
)
for image_key in image_keys
}
elif encoder_type == "mobilenet":
from jeffnet.linen import create_model

# encoder, encoder_params = create_model('tf_mobilenetv3_large_100', pretrained=True)
encoder, encoder_params = create_model(
"tf_mobilenetv3_small_minimal_100", pretrained=True
)
encoder_defs = {
image_key: MobileNetEncoder(
encoder=encoder,
params=encoder_params,
pool_method="spatial_learned_embeddings",
bottleneck_dim=256,
spatial_block_size=8,
name=f"encoder_{image_key}",
)
for image_key in image_keys
}
elif encoder_type == "resnet":
raise NotImplementedError(f"Unknown encoder type: {encoder_type}")
else:
raise NotImplementedError(f"Unknown encoder type: {encoder_type}")

return DrQAgent.create_drq(
jax.random.PRNGKey(seed),
sample_obs,
sample_action,
encoder=encoder_defs,
use_proprio=True,
image_keys=image_keys,
policy_kwargs={
"tanh_squash_distribution": True,
"std_parameterization": "exp",
"std_min": 1e-5,
"std_max": 5,
},
critic_network_kwargs={
"activations": nn.tanh,
Expand Down

0 comments on commit d8d74df

Please sign in to comment.