diff --git a/serl_launcher/serl_launcher/utils/jaxrl_m_common.py b/serl_launcher/serl_launcher/utils/jaxrl_m_common.py index d0b40304..6e8525ad 100644 --- a/serl_launcher/serl_launcher/utils/jaxrl_m_common.py +++ b/serl_launcher/serl_launcher/utils/jaxrl_m_common.py @@ -6,10 +6,14 @@ from collections import deque from functools import partial from threading import Lock +from typing import Union, Iterable import gymnasium as gym import jax from serl_launcher.data.serl_replay_buffer import ReplayBuffer +from serl_launcher.data.serl_memory_efficient_replay_buffer import ( + MemoryEfficientReplayBuffer, +) from edgeml.data.data_store import DataStoreBase from edgeml.trainer import TrainerConfig @@ -17,8 +21,9 @@ from jax import nn from jaxrl_m.agents.continuous.sac import SACAgent from jaxrl_m.common.wandb import WandBLogger - -############################################################################## +from jaxrl_m.agents.continuous.drq import DrQAgent +from jaxrl_m.vision.small_encoders import SmallEncoder +from jaxrl_m.vision.mobilenet import MobileNetEncoder class ReplayBufferDataStore(ReplayBuffer, DataStoreBase): @@ -51,6 +56,41 @@ def get_latest_data(self, from_id: int): raise NotImplementedError # TODO +class MemoryEfficientReplayBufferDataStore(MemoryEfficientReplayBuffer, DataStoreBase): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + image_keys: Iterable[str] = ("image",), + ): + MemoryEfficientReplayBuffer.__init__( + self, observation_space, action_space, capacity, pixel_keys=image_keys + ) + DataStoreBase.__init__(self, capacity) + self._lock = Lock() + + # ensure thread safety + def insert(self, *args, **kwargs): + with self._lock: + super(MemoryEfficientReplayBufferDataStore, self).insert(*args, **kwargs) + + # ensure thread safety + def sample(self, *args, **kwargs): + with self._lock: + return super(MemoryEfficientReplayBufferDataStore, self).sample( + *args, **kwargs + ) + + # NOTE: method for DataStoreBase + def latest_data_id(self): + return self._insert_index + + # NOTE: method for DataStoreBase + def get_latest_data(self, from_id: int): + raise NotImplementedError # TODO + + ############################################################################## @@ -62,6 +102,79 @@ def make_sac_agent(seed, sample_obs, sample_action): policy_kwargs={ "tanh_squash_distribution": True, "std_parameterization": "exp", + "std_min": 1e-5, + "std_max": 5, + }, + critic_network_kwargs={ + "activations": nn.tanh, + "use_layer_norm": True, + "hidden_dims": [256, 256], + }, + policy_network_kwargs={ + "activations": nn.tanh, + "use_layer_norm": True, + "hidden_dims": [256, 256], + }, + temperature_init=1e-2, + discount=0.99, + backup_entropy=False, + critic_ensemble_size=10, + critic_subsample_size=2, + ) + + +def make_drq_agent( + seed, sample_obs, sample_action, image_keys=("image",), encoder_type="small" +): + if encoder_type == "small": + encoder_defs = { + image_key: SmallEncoder( + features=(32, 64, 128, 256), + kernel_sizes=(3, 3, 3, 3), + strides=(2, 2, 2, 2), + padding="VALID", + pool_method="avg", + bottleneck_dim=256, + spatial_block_size=8, + name=f"encoder_{image_key}", + ) + for image_key in image_keys + } + elif encoder_type == "mobilenet": + from jeffnet.linen import create_model + + # encoder, encoder_params = create_model('tf_mobilenetv3_large_100', pretrained=True) + encoder, encoder_params = create_model( + "tf_mobilenetv3_small_minimal_100", pretrained=True + ) + encoder_defs = { + image_key: MobileNetEncoder( + encoder=encoder, + params=encoder_params, + pool_method="spatial_learned_embeddings", + bottleneck_dim=256, + spatial_block_size=8, + name=f"encoder_{image_key}", + ) + for image_key in image_keys + } + elif encoder_type == "resnet": + raise NotImplementedError(f"Unknown encoder type: {encoder_type}") + else: + raise NotImplementedError(f"Unknown encoder type: {encoder_type}") + + return DrQAgent.create_drq( + jax.random.PRNGKey(seed), + sample_obs, + sample_action, + encoder=encoder_defs, + use_proprio=True, + image_keys=image_keys, + policy_kwargs={ + "tanh_squash_distribution": True, + "std_parameterization": "exp", + "std_min": 1e-5, + "std_max": 5, }, critic_network_kwargs={ "activations": nn.tanh,