
Commit

[RLlib] MultiAgentEpisode for Multi-Agent Reinforcement Learning with the new `EnvRunner` API. (ray-project#40263)
simonsays1980 authored Oct 30, 2023
1 parent f0fe37e commit 1ee167d
Showing 11 changed files with 1,938 additions and 382 deletions.
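
Taken together, the hunks below apply one rename: the episode container moves out of the replay-buffer module and into `ray.rllib.env.single_agent_episode`. A minimal sketch of the new import and the call pattern used at the updated call sites (the scalar observation, action, and reward values here are placeholders):

```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# New home of the episode class, replacing
# `from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode`.
episode = SingleAgentEpisode(observations=[0.0])

# Append one environment step (positional obs, action, reward, as in the
# updated call sites below).
episode.add_timestep(1.0, 0, 1.0)
```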
21 changes: 14 additions & 7 deletions rllib/BUILD
@@ -760,13 +760,6 @@ py_test(
# Tag: env
# --------------------------------------------------------------------

py_test(
name = "env/tests/test_single_agent_gym_env_runner",
tags = ["team:rllib", "env"],
size = "medium",
srcs = ["env/tests/test_single_agent_gym_env_runner.py"]
)

py_test(
name = "env/tests/test_env_with_subprocess",
tags = ["team:rllib", "env"],
@@ -869,6 +862,20 @@ py_test(
srcs = ["env/tests/test_remote_worker_envs.py"]
)

py_test(
name = "env/tests/test_single_agent_gym_env_runner",
tags = ["team:rllib", "env"],
size = "medium",
srcs = ["env/tests/test_single_agent_gym_env_runner.py"]
)

py_test(
name = "env/tests/test_single_agent_episode",
tags = ["team:rllib", "env"],
size = "medium",
srcs = ["env/tests/test_single_agent_episode.py"]
)

py_test(
name = "env/wrappers/tests/test_exception_wrapper",
tags = ["team:rllib", "env"],
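The BUILD file now registers `env/tests/test_single_agent_episode`. The test file itself is not part of the hunks shown here; the snippet below is only a hypothetical sketch of the kind of check such a target could run, using the `SingleAgentEpisode` calls that do appear in this diff:

```python
import unittest

import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode


class TestSingleAgentEpisode(unittest.TestCase):
    def test_add_timestep(self):
        # Start from a single reset observation, then append one step.
        episode = SingleAgentEpisode(observations=[np.zeros(4)])
        episode.add_timestep(np.ones(4), 0, 1.0)
        # One reset observation plus one stepped observation.
        self.assertEqual(len(episode.observations), 2)


if __name__ == "__main__":
    unittest.main()
```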
21 changes: 12 additions & 9 deletions rllib/algorithms/dreamerv3/utils/env_runner.py
@@ -25,7 +25,7 @@
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.numpy import one_hot
from ray.tune.registry import ENV_CREATOR, _global_registry

@@ -163,7 +163,7 @@ def sample(
explore: bool = True,
random_actions: bool = False,
with_render_data: bool = False,
) -> Tuple[List[Episode], List[Episode]]:
) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]:
"""Runs and returns a sample (n timesteps or m episodes) on the environment(s).
Timesteps or episodes are counted in total (across all vectorized
@@ -229,7 +229,7 @@ def _sample_timesteps(
explore: bool = True,
random_actions: bool = False,
force_reset: bool = False,
) -> Tuple[List[Episode], List[Episode]]:
) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]:
"""Helper method to run n timesteps.
See docstring of self.sample() for more details.
@@ -246,7 +246,7 @@ def _sample_timesteps(
if force_reset or self._needs_initial_reset:
obs, _ = self.env.reset()

self._episodes = [Episode() for _ in range(self.num_envs)]
self._episodes = [SingleAgentEpisode() for _ in range(self.num_envs)]
states = initial_states
# Set is_first to True for all rows (all sub-envs just got reset).
is_first = np.ones((self.num_envs,))
@@ -263,7 +263,8 @@
# Pick up stored observations and states from previous timesteps.
obs = np.stack([eps.observations[-1] for eps in self._episodes])
# Compile the initial state for each batch row: If episode just started, use
# model's initial state, if not, use state stored last in Episode.
# model's initial state, if not, use state stored last in
# SingleAgentEpisode.
states = {
k: np.stack(
[
@@ -333,7 +334,9 @@ def _sample_timesteps(
is_first[i] = True
done_episodes_to_return.append(self._episodes[i])
# Create a new episode object.
self._episodes[i] = Episode(observations=[obs[i]], states=s)
self._episodes[i] = SingleAgentEpisode(
observations=[obs[i]], states=s
)
else:
self._episodes[i].add_timestep(
obs[i], actions[i], rewards[i], state=s
@@ -360,15 +363,15 @@ def _sample_episodes(
explore: bool = True,
random_actions: bool = False,
with_render_data: bool = False,
) -> List[Episode]:
) -> List[SingleAgentEpisode]:
"""Helper method to run n episodes.
See docstring of `self.sample()` for more details.
"""
done_episodes_to_return = []

obs, _ = self.env.reset()
episodes = [Episode() for _ in range(self.num_envs)]
episodes = [SingleAgentEpisode() for _ in range(self.num_envs)]

# Multiply states n times according to our vector env batch size (num_envs).
states = tree.map_structure(
@@ -443,7 +446,7 @@ def _sample_episodes(
states[k][i] = v.numpy()
is_first[i] = True

episodes[i] = Episode(
episodes[i] = SingleAgentEpisode(
observations=[obs[i]],
states=s,
render_images=[render_images[i]],
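Condensed from the per-sub-env loop above: when a sub-env finishes, its `SingleAgentEpisode` is handed back and a fresh one is started from the reset observation; otherwise the step is appended. The sketch below substitutes dummy arrays for the vectorized env outputs and omits the RNN `state` handling from the real loop:

```python
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Illustrative stand-ins for the vectorized env outputs.
num_envs = 2
obs = np.zeros((num_envs, 4))
episodes = [SingleAgentEpisode(observations=[obs[i]]) for i in range(num_envs)]
done_episodes_to_return = []

next_obs = np.ones((num_envs, 4))
actions = np.array([0, 1])
rewards = np.array([1.0, 0.5])
terminateds = np.array([False, True])

for i in range(num_envs):
    if terminateds[i]:
        # Hand the finished episode back and start a fresh one from the
        # (already reset) observation, as the loop above does on `done`.
        done_episodes_to_return.append(episodes[i])
        episodes[i] = SingleAgentEpisode(observations=[next_obs[i]])
    else:
        # Otherwise append this step's obs, action, and reward.
        episodes[i].add_timestep(next_obs[i], actions[i], rewards[i])
```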
21 changes: 13 additions & 8 deletions rllib/algorithms/ppo/ppo.py
@@ -49,7 +49,8 @@
SAMPLE_TIMER,
ALL_MODULES,
)
from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import ResultDict
from ray.util.debug import log_once
@@ -443,17 +444,19 @@ def training_step(self) -> ResultDict:
# New Episode-returning EnvRunner API.
else:
if self.workers.num_remote_workers() <= 0:
episodes = [self.workers.local_worker().sample()]
episodes: List[SingleAgentEpisode] = [
self.workers.local_worker().sample()
]
else:
episodes = self.workers.foreach_worker(
episodes: List[SingleAgentEpisode] = self.workers.foreach_worker(
lambda w: w.sample(), local_worker=False
)
# Perform PPO postprocessing on a (flattened) list of Episodes.
postprocessed_episodes = self.postprocess_episodes(
tree.flatten(episodes)
)
postprocessed_episodes: List[
SingleAgentEpisode
] = self.postprocess_episodes(tree.flatten(episodes))
# Convert list of postprocessed Episodes into a single sample batch.
train_batch = postprocess_episodes_to_sample_batch(
train_batch: SampleBatch = postprocess_episodes_to_sample_batch(
postprocessed_episodes
)

@@ -610,7 +613,9 @@ def training_step(self) -> ResultDict:

return train_results

def postprocess_episodes(self, episodes: List[Episode]) -> List[Episode]:
def postprocess_episodes(
self, episodes: List[SingleAgentEpisode]
) -> List[SingleAgentEpisode]:
"""Calculate advantages and value targets."""
from ray.rllib.evaluation.postprocessing_v2 import compute_gae_for_episode

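The updated `training_step()` and `postprocess_episodes()` annotate a simple pipeline: a flat list of `SingleAgentEpisode` objects is postprocessed (GAE advantages and value targets) and then collapsed into one `SampleBatch`. The sketch below traces only the types; `postprocess` and `to_sample_batch` are illustrative placeholders for `PPO.postprocess_episodes()` and `postprocess_episodes_to_sample_batch()`, not RLlib APIs:

```python
from typing import List

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch


def postprocess(episodes: List[SingleAgentEpisode]) -> List[SingleAgentEpisode]:
    # Placeholder for PPO.postprocess_episodes(), which computes GAE
    # advantages and value targets per episode.
    return episodes


def to_sample_batch(episodes: List[SingleAgentEpisode]) -> SampleBatch:
    # Placeholder for postprocess_episodes_to_sample_batch(), which
    # concatenates the per-episode data into one training batch.
    return SampleBatch()


episodes: List[SingleAgentEpisode] = [SingleAgentEpisode()]
train_batch: SampleBatch = to_sample_batch(postprocess(episodes))
```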
(Diffs for the remaining 8 changed files are not shown.)
