diff --git a/.gitignore b/.gitignore index 3d8a3939..00795270 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__/ # MUJOCO_LOG.TXT MUJOCO_LOG.TXT +*.pkl diff --git a/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py new file mode 100644 index 00000000..835176a2 --- /dev/null +++ b/examples/async_rlpd_drq_sim/async_rlpd_drq_sim.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python3 + +# NOTE: this requires jaxrl_m to be installed: +# https://github.com/rail-berkeley/jaxrl_minimal + +import time +from functools import partial +import jax +import jax.numpy as jnp +import numpy as np +import tqdm +from absl import app, flags + +import gymnasium as gym +from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics + +from jaxrl_m.agents.continuous.drq import DrQAgent +from jaxrl_m.common.evaluation import evaluate +from jaxrl_m.utils.timer_utils import Timer +from jaxrl_m.envs.wrappers.chunking import ChunkingWrapper +from jaxrl_m.utils.train_utils import concat_batches + +from edgeml.trainer import TrainerServer, TrainerClient, TrainerTunnel +from edgeml.data.data_store import QueuedDataStore + +from serl_launcher.utils.jaxrl_m_common import ( + MemoryEfficientReplayBufferDataStore, + make_drq_agent, + make_trainer_config, + make_wandb_logger, +) +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.data.serl_memory_efficient_replay_buffer import ( + MemoryEfficientReplayBuffer, +) + +FLAGS = flags.FLAGS + +flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") +flags.DEFINE_string("agent", "drq", "Name of agent.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") +flags.DEFINE_integer("seed", 42, "Random seed.") +flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") +flags.DEFINE_integer("utd_ratio", 4, "UTD ratio.") + +flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") +flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.") + +flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") +flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") +flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.") + +flags.DEFINE_integer("log_period", 10, "Logging period.") +flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") +flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.") + +# flag to indicate if this is a leaner or a actor +flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("render", False, "Render the environment.") +flags.DEFINE_string("ip", "localhost", "IP address of the learner.") + +flags.DEFINE_boolean( + "debug", False, "Debug mode." +) # debug mode will disable wandb logging + + +def print_green(x): + return print("\033[92m {}\033[00m".format(x)) + + +############################################################################## + + +def actor(agent: DrQAgent, data_store, env, sampling_rng, tunnel=None): + """ + This is the actor loop, which runs when "--actor" is set to True. + NOTE: tunnel is used the transport layer for multi-threading + """ + if tunnel: + client = tunnel + else: + client = TrainerClient( + "actor_env", + FLAGS.ip, + make_trainer_config(), + data_store, + wait_for_server=True, + ) + + # Function to update the agent with new params + def update_params(params): + nonlocal agent + agent = agent.replace(state=agent.state.replace(params=params)) + + client.recv_network_callback(update_params) + + eval_env = gym.make(FLAGS.env) + if FLAGS.env == "PandaPickCubeVision-v0": + eval_env = SERLObsWrapper(eval_env) + eval_env = ChunkingWrapper(eval_env, obs_horizon=1, act_exec_horizon=None) + eval_env = RecordEpisodeStatistics(eval_env) + + obs, _ = env.reset() + done = False + + # training loop + timer = Timer() + running_return = 0.0 + + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): + timer.tick("total") + + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) + + # Step environment + with timer.context("step_env"): + + next_obs, reward, done, truncated, info = env.step(actions) + reward = np.asarray(reward, dtype=np.float32) + info = np.asarray(info) + running_return += reward + transition = dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done, + ) + data_store.insert(transition) + + obs = next_obs + if done or truncated: + running_return = 0.0 + obs, _ = env.reset() + + if step % FLAGS.steps_per_update == 0: + client.update() + + if step % FLAGS.eval_period == 0: + with timer.context("eval"): + evaluate_info = evaluate( + policy_fn=partial(agent.sample_actions, argmax=True), + env=eval_env, + num_episodes=FLAGS.eval_n_trajs, + ) + stats = {"eval": evaluate_info} + client.request("send-stats", stats) + + timer.tock("total") + + if step % FLAGS.log_period == 0: + stats = {"timer": timer.get_average_times()} + client.request("send-stats", stats) + + +############################################################################## + + +def learner( + rng, + agent: DrQAgent, + replay_buffer, + replay_iterator, + demo_iterator=None, + wandb_logger=None, + tunnel=None, +): + """ + The learner loop, which runs when "--learner" is set to True. + NOTE: tunnel is used the transport layer for multi-threading + """ + # To track the step in the training loop + update_steps = 0 + + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + if wandb_logger is not None: + wandb_logger.log(payload, step=update_steps) + return {} # not expecting a response + + # Create server + if tunnel: + tunnel.register_request_callback(stats_callback) + server = tunnel + else: + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + server.start(threaded=True) + + # Loop to wait until replay_buffer is filled + pbar = tqdm.tqdm( + total=FLAGS.training_starts, + initial=len(replay_buffer), + desc="Filling up replay buffer", + position=0, + leave=True, + ) + while len(replay_buffer) < FLAGS.training_starts: + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + time.sleep(1) + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + pbar.close() + + # send the initial network to the actor + server.publish_network(agent.state.params) + print_green("sent initial network to actor") + + # wait till the replay buffer is filled with enough data + timer = Timer() + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"): + # run n-1 critic updates and 1 critic + actor update. + # This makes training on GPU faster by reducing the large batch transfer time from CPU to GPU + for critic_step in range(FLAGS.utd_ratio - 1): + with timer.context("sample_replay_buffer"): + batch = next(replay_iterator) + if demo_iterator is not None: + demo_batch = next(demo_iterator) + batch = concat_batches(batch, demo_batch, axis=0) + + with timer.context("train_critics"): + agent, critics_info = agent.update_critics( + batch, + ) + agent = jax.block_until_ready(agent) + + with timer.context("train"): + batch = next(replay_iterator) + if demo_iterator is not None: + demo_batch = next(demo_iterator) + batch = concat_batches(batch, demo_batch, axis=0) + agent, update_info = agent.update_high_utd(batch, utd_ratio=1) + agent = jax.block_until_ready(agent) + + # publish the updated network + if step > 0 and step % (FLAGS.steps_per_update) == 0: + server.publish_network(agent.state.params) + + if update_steps % FLAGS.log_period == 0 and wandb_logger: + wandb_logger.log(update_info, step=update_steps) + wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) + + update_steps += 1 + + +############################################################################## + + +def main(_): + devices = jax.local_devices() + num_devices = len(devices) + sharding = jax.sharding.PositionalSharding(devices) + assert FLAGS.batch_size % num_devices == 0 + + # seed + rng = jax.random.PRNGKey(FLAGS.seed) + + # create env and load dataset + if FLAGS.render: + env = gym.make(FLAGS.env, render_mode="human") + else: + env = gym.make(FLAGS.env) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + if FLAGS.env == "PandaPickCubeVision-v0": + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] + + rng, sampling_rng = jax.random.split(rng) + agent: DrQAgent = make_drq_agent( + seed=FLAGS.seed, + sample_obs=env.observation_space.sample(), + sample_action=env.action_space.sample(), + image_keys=image_keys, + encoder_type="mobilenet", + ) + + # replicate agent across devices + # need the jnp.array to avoid a bug where device_put doesn't recognize primitives + agent: DrQAgent = jax.device_put( + jax.tree_map(jnp.array, agent), sharding.replicate() + ) + + def create_replay_buffer_and_wandb_logger(): + replay_buffer = MemoryEfficientReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=FLAGS.replay_buffer_capacity, + image_keys=image_keys, + ) + # set up wandb and logging + wandb_logger = make_wandb_logger( + project="serl_dev", + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + return replay_buffer, wandb_logger + + if FLAGS.learner: + sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) + replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": FLAGS.batch_size // 2, + "pack_obs_and_next_obs": True, + }, + device=sharding.replicate(), + ) + demo_buffer = MemoryEfficientReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=10000, + image_keys=image_keys, + ) + demo_iterator = demo_buffer.get_iterator( + sample_args={ + "batch_size": FLAGS.batch_size // 2, + "pack_obs_and_next_obs": True, + }, + device=sharding.replicate(), + ) + import pickle as pkl + + with open("trajs.pkl", "rb") as f: + trajs = pkl.load(f) + for traj in trajs: + demo_buffer.insert(traj) + print(f"replay buffer size: {len(demo_buffer)}") + + # learner loop + print_green("starting learner loop") + learner( + sampling_rng, + agent, + replay_buffer, + replay_iterator=replay_iterator, + demo_iterator=demo_iterator, + wandb_logger=wandb_logger, + tunnel=None, + ) + + elif FLAGS.actor: + sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) + data_store = QueuedDataStore(50000) # the queue size on the actor + + # actor loop + print_green("starting actor loop") + actor(agent, data_store, env, sampling_rng, tunnel=None) + + else: + print_green("starting actor and learner loop with multi-threading") + + # In this example, the tunnel acts as the transport layer for the + # trainerServer and trainerClient. Also, both actor and learner shares + # the same replay buffer. + replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + + tunnel = TrainerTunnel() + sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) + + import threading + + # Start learner thread + learner_thread = threading.Thread( + target=learner, args=(agent, replay_buffer, wandb_logger, tunnel) + ) + learner_thread.start() + + # Start actor in main process + actor(agent, replay_buffer, env, sampling_rng, tunnel=tunnel) + learner_thread.join() + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/async_rlpd_drq_sim/run_actor.sh b/examples/async_rlpd_drq_sim/run_actor.sh new file mode 100644 index 00000000..06219b03 --- /dev/null +++ b/examples/async_rlpd_drq_sim/run_actor.sh @@ -0,0 +1,14 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +python async_rlpd_drq_sim.py \ + --actor \ + --render \ + --env PandaPickCubeVision-v0 \ + --exp_name=serl_dev_rlpd_drq_sim_test_mobilenet \ + --seed 0 \ + --random_steps 1000 \ + --training_starts 1000 \ + --utd_ratio 4 \ + --batch_size 256 \ + --eval_period 2000 \ + --debug diff --git a/examples/async_rlpd_drq_sim/run_learner.sh b/examples/async_rlpd_drq_sim/run_learner.sh new file mode 100644 index 00000000..7a834a5a --- /dev/null +++ b/examples/async_rlpd_drq_sim/run_learner.sh @@ -0,0 +1,13 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +python async_rlpd_drq_sim.py \ + --learner \ + --env PandaPickCubeVision-v0 \ + --exp_name=serl_dev_rlpd_drq_sim_test_mobilenet \ + --seed 0 \ + --random_steps 1000 \ + --training_starts 1000 \ + --utd_ratio 4 \ + --batch_size 256 \ + --eval_period 2000 \ + --debug