From 1ee167d4e7c1f0ff93f75be4648aeed19efffefe Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Mon, 30 Oct 2023 14:10:14 +0100 Subject: [PATCH] [RLlib] `MultiAgentEpisode` for Multi-Agent Reinforcement Learning with the new `EnvRunner` API. (#40263) --- rllib/BUILD | 21 +- .../algorithms/dreamerv3/utils/env_runner.py | 21 +- rllib/algorithms/ppo/ppo.py | 21 +- rllib/env/multi_agent_episode.py | 872 ++++++++++++++++++ rllib/env/single_agent_env_runner.py | 34 +- rllib/env/single_agent_episode.py | 528 +++++++++++ .../testing/single_agent_gym_env_runner.py | 24 +- rllib/env/tests/test_single_agent_episode.py | 453 +++++++++ rllib/evaluation/postprocessing_v2.py | 14 +- .../replay_buffers/episode_replay_buffer.py | 326 +------ .../tests/test_episode_replay_buffer.py | 6 +- 11 files changed, 1938 insertions(+), 382 deletions(-) create mode 100644 rllib/env/multi_agent_episode.py create mode 100644 rllib/env/single_agent_episode.py create mode 100644 rllib/env/tests/test_single_agent_episode.py diff --git a/rllib/BUILD b/rllib/BUILD index 7854a6bcc77f..833ac21f6add 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -760,13 +760,6 @@ py_test( # Tag: env # -------------------------------------------------------------------- -py_test( - name = "env/tests/test_single_agent_gym_env_runner", - tags = ["team:rllib", "env"], - size = "medium", - srcs = ["env/tests/test_single_agent_gym_env_runner.py"] -) - py_test( name = "env/tests/test_env_with_subprocess", tags = ["team:rllib", "env"], @@ -869,6 +862,20 @@ py_test( srcs = ["env/tests/test_remote_worker_envs.py"] ) +py_test( + name = "env/tests/test_single_agent_gym_env_runner", + tags = ["team:rllib", "env"], + size = "medium", + srcs = ["env/tests/test_single_agent_gym_env_runner.py"] +) + +py_test( + name = "env/tests/test_single_agent_episode", + tags = ["team:rllib", "env"], + size = "medium", + srcs = ["env/tests/test_single_agent_episode.py"] +) + py_test( name = "env/wrappers/tests/test_exception_wrapper", tags = ["team:rllib", "env"], diff --git a/rllib/algorithms/dreamerv3/utils/env_runner.py b/rllib/algorithms/dreamerv3/utils/env_runner.py index 60898e936d27..259a27e4f7df 100644 --- a/rllib/algorithms/dreamerv3/utils/env_runner.py +++ b/rllib/algorithms/dreamerv3/utils/env_runner.py @@ -25,7 +25,7 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.numpy import one_hot from ray.tune.registry import ENV_CREATOR, _global_registry @@ -163,7 +163,7 @@ def sample( explore: bool = True, random_actions: bool = False, with_render_data: bool = False, - ) -> Tuple[List[Episode], List[Episode]]: + ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: """Runs and returns a sample (n timesteps or m episodes) on the environment(s). Timesteps or episodes are counted in total (across all vectorized @@ -229,7 +229,7 @@ def _sample_timesteps( explore: bool = True, random_actions: bool = False, force_reset: bool = False, - ) -> Tuple[List[Episode], List[Episode]]: + ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: """Helper method to run n timesteps. See docstring of self.sample() for more details. 
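The import changes above (and the analogous ones in `ppo.py`, `single_agent_env_runner.py`, and `testing/single_agent_gym_env_runner.py` below) all apply the same migration: the private `_Episode` class that previously lived in the episode replay buffer module is replaced by the new public `SingleAgentEpisode`. A minimal before/after sketch of the swap, using only names that appear in this patch (`obs` and `state` are placeholders for a single observation and hidden model state, not values from this code):

    # Before this patch: the episode container was private to the replay buffer.
    from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode
    episode = Episode(observations=[obs], states=state)

    # After this patch: a public, reusable single-agent episode class.
    from ray.rllib.env.single_agent_episode import SingleAgentEpisode
    episode = SingleAgentEpisode(observations=[obs], states=state)
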
@@ -246,7 +246,7 @@ def _sample_timesteps( if force_reset or self._needs_initial_reset: obs, _ = self.env.reset() - self._episodes = [Episode() for _ in range(self.num_envs)] + self._episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] states = initial_states # Set is_first to True for all rows (all sub-envs just got reset). is_first = np.ones((self.num_envs,)) @@ -263,7 +263,8 @@ def _sample_timesteps( # Pick up stored observations and states from previous timesteps. obs = np.stack([eps.observations[-1] for eps in self._episodes]) # Compile the initial state for each batch row: If episode just started, use - # model's initial state, if not, use state stored last in Episode. + # model's initial state, if not, use state stored last in + # SingleAgentEpisode. states = { k: np.stack( [ @@ -333,7 +334,9 @@ def _sample_timesteps( is_first[i] = True done_episodes_to_return.append(self._episodes[i]) # Create a new episode object. - self._episodes[i] = Episode(observations=[obs[i]], states=s) + self._episodes[i] = SingleAgentEpisode( + observations=[obs[i]], states=s + ) else: self._episodes[i].add_timestep( obs[i], actions[i], rewards[i], state=s @@ -360,7 +363,7 @@ def _sample_episodes( explore: bool = True, random_actions: bool = False, with_render_data: bool = False, - ) -> List[Episode]: + ) -> List[SingleAgentEpisode]: """Helper method to run n episodes. See docstring of `self.sample()` for more details. @@ -368,7 +371,7 @@ def _sample_episodes( done_episodes_to_return = [] obs, _ = self.env.reset() - episodes = [Episode() for _ in range(self.num_envs)] + episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] # Multiply states n times according to our vector env batch size (num_envs). states = tree.map_structure( @@ -443,7 +446,7 @@ def _sample_episodes( states[k][i] = v.numpy() is_first[i] = True - episodes[i] = Episode( + episodes[i] = SingleAgentEpisode( observations=[obs[i]], states=s, render_images=[render_images[i]], diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 6bd25c6d62e9..900fac4d163d 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -49,7 +49,8 @@ SAMPLE_TIMER, ALL_MODULES, ) -from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ResultDict from ray.util.debug import log_once @@ -443,17 +444,19 @@ def training_step(self) -> ResultDict: # New Episode-returning EnvRunner API. else: if self.workers.num_remote_workers() <= 0: - episodes = [self.workers.local_worker().sample()] + episodes: List[SingleAgentEpisode] = [ + self.workers.local_worker().sample() + ] else: - episodes = self.workers.foreach_worker( + episodes: List[SingleAgentEpisode] = self.workers.foreach_worker( lambda w: w.sample(), local_worker=False ) # Perform PPO postprocessing on a (flattened) list of Episodes. - postprocessed_episodes = self.postprocess_episodes( - tree.flatten(episodes) - ) + postprocessed_episodes: List[ + SingleAgentEpisode + ] = self.postprocess_episodes(tree.flatten(episodes)) # Convert list of postprocessed Episodes into a single sample batch. 
- train_batch = postprocess_episodes_to_sample_batch( + train_batch: SampleBatch = postprocess_episodes_to_sample_batch( postprocessed_episodes ) @@ -610,7 +613,9 @@ def training_step(self) -> ResultDict: return train_results - def postprocess_episodes(self, episodes: List[Episode]) -> List[Episode]: + def postprocess_episodes( + self, episodes: List[SingleAgentEpisode] + ) -> List[SingleAgentEpisode]: """Calculate advantages and value targets.""" from ray.rllib.evaluation.postprocessing_v2 import compute_gae_for_episode diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py new file mode 100644 index 000000000000..e64edbd08a59 --- /dev/null +++ b/rllib/env/multi_agent_episode.py @@ -0,0 +1,872 @@ +import numpy as np +import uuid + +from typing import Any, Dict, List, Optional, Union + +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.typing import MultiAgentDict + + +# TODO (simon): Include cases in which the number of agents in an +# episode are shrinking or growing during the episode itself. +class MultiAgentEpisode: + """Stores multi-agent episode data. + + The central attribute of the class is the timestep mapping + `global_t_to_local_t` that maps the global (environment) + timestep to the local (agent) timesteps. + + The `MultiAgentEpisode` is based on the `SingleAgentEpisode`s + for each agent, stored in `MultiAgentEpisode.agent_episodes`. + """ + + def __init__( + self, + id_: Optional[str] = None, + agent_ids: List[str] = None, + agent_episode_ids: Optional[Dict[str, str]] = None, + *, + observations: Optional[List[MultiAgentDict]] = None, + actions: Optional[List[MultiAgentDict]] = None, + rewards: Optional[List[MultiAgentDict]] = None, + states: Optional[List[MultiAgentDict]] = None, + infos: Optional[List[MultiAgentDict]] = None, + t_started: int = 0, + is_terminated: Optional[bool] = False, + is_truncated: Optional[bool] = False, + render_images: Optional[List[np.ndarray]] = None, + extra_model_outputs: Optional[List[MultiAgentDict]] = None, + # TODO (simon): Also allow to receive `extra_model_outputs`. + # TODO (simon): Validate terminated/truncated for env/agents. + ) -> "MultiAgentEpisode": + """Initializes a `MultiAgentEpisode`. + + Args: + id_: Optional. Either a string to identify an episode or None. + If None, a hexadecimal id is created. In case of providing + a string, make sure that it is unique, as episodes get + concatenated via this string. + agent_ids: Obligatory. A list of strings containing the agent ids. + These have to be provided at initialization. + agent_episode_ids: Optional. Either a dictionary mapping agent ids + corresponding `SingleAgentEpisode` or None. If None, each + `SingleAgentEpisode` in `MultiAgentEpisode.agent_episodes` + will generate a hexadecimal code. If a dictionary is provided + make sure that ids are unique as agents' `SingleAgentEpisode`s + get concatenated or recreated by it. + observations: A dictionary mapping from agent ids to observations. + Can be None. If provided, it should be provided together with + all other episode data (actions, rewards, etc.) + actions: A dictionary mapping from agent ids to corresponding actions. + Can be None. If provided, it should be provided together with + all other episode data (observations, rewards, etc.). + rewards: A dictionary mapping from agent ids to corresponding rewards. + Can be None. 
If provided, it should be provided together with + all other episode data (observations, rewards, etc.). + infos: A dictionary mapping from agent ids to corresponding infos. + Can be None. If provided, it should be provided together with + all other episode data (observations, rewards, etc.). + states: A dictionary mapping from agent ids to their corresponding + modules' hidden states. These will be stored into the + `SingleAgentEpisode`s in `MultiAgentEpisode.agent_episodes`. + Can be None. + t_started: Optional. An unsigned int that defines the starting point + of the episode. This is only different from zero, if an ongoing + episode is created. + is_terminazted: Optional. A boolean defining, if an environment has + terminated. The default is `False`, i.e. the episode is ongoing. + is_truncated: Optional. A boolean, defining, if an environment is + truncated. The default is `False`, i.e. the episode is ongoing. + render_images: Optional. A list of RGB uint8 images from rendering + the environment. + extra_model_outputs: Optional. A dictionary mapping agent ids to their + corresponding extra model outputs. Each of the latter is a list of + dictionaries containing specific model outputs for the algorithm + used (e.g. `vf_preds` and `action_logp` for PPO) from a rollout. + If data is provided it should be complete (i.e. observations, + actions, rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). + """ + + self.id_: str = id_ or uuid.uuid4().hex + + # Agent ids must be provided if data is provided. The Episode cannot + # know how many agents are in the environment. Also the number of agents + # can grwo or shrink. + self._agent_ids: Union[List[str], List[object]] = ( + [] if agent_ids is None else agent_ids + ) + + # The global last timestep of the episode and the timesteps when this chunk + # started. + self.t = self.t_started = ( + t_started if t_started is not None else max(len(observations) - 1, 0) + ) + # Keeps track of the correspondence between agent steps and environment steps. + # This is a mapping from agents to `IndexMapping`. The latter keeps + # track of the global timesteps at which an agent stepped. + self.global_t_to_local_t: Dict[str, List[int]] = self._generate_ts_mapping( + observations + ) + + # Note that all attributes will be recorded along the global timestep + # in an multi-agent environment. `SingleAgentEpisodes` + self.agent_episodes: MultiAgentDict = { + agent_id: self._generate_single_agent_episode( + agent_id, + agent_episode_ids, + observations, + actions, + rewards, + infos, + states, + extra_model_outputs, + ) + for agent_id in self._agent_ids + } + + # obs[-1] is the final observation in the episode. + self.is_terminated: bool = is_terminated + # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more + # observations in following chunks for this episode). + self.is_truncated: bool = is_truncated + # RGB uint8 images from rendering the env; the images include the corresponding + # rewards. + assert render_images is None or observations is not None + self.render_images: Union[List[np.ndarray], List[object]] = ( + [] if render_images is None else render_images + ) + + def concat_episode(self, episode_chunk: "MultiAgentEpisode") -> None: + """Adds the given `episode_chunk` to the right side of self. + + For concatenating episodes the following rules hold: + - IDs are identical. + - timesteps match (`t` of `self` matches `t_started` of `episode_chunk`). 
+ + Args: + episode_chunk: `MultiAgentEpsiode` instance that should be concatenated + to `self`. + """ + assert episode_chunk.id_ == self.id_ + assert not self.is_done + # Make sure the timesteps match. + assert self.t == episode_chunk.t_started + + # TODO (simon): Write `validate()` method. + + # Make sure, end matches `episode_chunk`'s beginning for all agents. + observations: MultiAgentDict = self.get_observations() + for agent_id, agent_obs in episode_chunk.get_observations(indices=0): + # Make sure that the same agents stepped at both timesteps. + assert agent_id in observations + assert observations[agent_id] == agent_obs + # Pop out the end for the agents that stepped. + for agent_id in observations: + self.agent_episodes[agent_id].observations.pop() + + # Call the `SingleAgentEpisode`'s `concat_episode()` method for all agents. + for agent_id, agent_eps in self.agent_episodes: + agent_eps[agent_id].concat_episode(episode_chunk.agent_episodes[agent_id]) + # Update our timestep mapping. + # TODO (simon): Check, if we have to cut off here as well. + self.global_t_to_local_t[agent_id][ + :-1 + ] += episode_chunk.global_t_to_local_t[agent_id] + + self.t = episode_chunk.t + if episode_chunk.is_terminated: + self.is_terminated = True + if episode_chunk.is_truncated: + self.is_truncated = True + + # Validate + # TODO (simon): Write validate function. + # self.validate() + + # TODO (simon): Maybe adding agent axis. We might need only some agent observations. + # Then also add possibility to get __all__ obs (or None) + # Write many test cases (numbered obs). + def get_observations( + self, indices: Union[int, List[int]] = -1, global_ts: bool = True + ) -> MultiAgentDict: + """Gets observations for all agents that stepped in the last timesteps. + + Note that observations are only returned for agents that stepped + during the given index range. + + Args: + indices: Either a single index or a list of indices. The indices + can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). + This defines the time indices for which the observations + should be returned. + global_ts: Boolean that defines, if the indices should be considered + environment (`True`) or agent (`False`) steps. + + Returns: A dictionary mapping agent ids to observations (of different + timesteps). Only for agents that have stepped (were ready) at a + timestep, observations are returned (i.e. not all agent ids are + necessarily in the keys). + """ + + return self._getattr_by_index("observations", indices, global_ts) + + def get_actions( + self, indices: Union[int, List[int]] = -1, global_ts: bool = True + ) -> MultiAgentDict: + """Gets actions for all agents that stepped in the last timesteps. + + Note that actions are only returned for agents that stepped + during the given index range. + + Args: + indices: Either a single index or a list of indices. The indices + can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). + This defines the time indices for which the actions + should be returned. + global_ts: Boolean that defines, if the indices should be considered + environment (`True`) or agent (`False`) steps. + + Returns: A dictionary mapping agent ids to actions (of different + timesteps). Only for agents that have stepped (were ready) at a + timestep, actions are returned (i.e. not all agent ids are + necessarily in the keys). 
+ """ + + return self._getattr_by_index("actions", indices, global_ts) + + def get_rewards( + self, indices: Union[int, List[int]] = -1, global_ts: bool = True + ) -> MultiAgentDict: + """Gets rewards for all agents that stepped in the last timesteps. + + Note that rewards are only returned for agents that stepped + during the given index range. + + Args: + indices: Either a single index or a list of indices. The indices + can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). + This defines the time indices for which the rewards + should be returned. + global_ts: Boolean that defines, if the indices should be considered + environment (`True`) or agent (`False`) steps. + + Returns: A dictionary mapping agent ids to rewards (of different + timesteps). Only for agents that have stepped (were ready) at a + timestep, rewards are returned (i.e. not all agent ids are + necessarily in the keys). + """ + return self._getattr_by_index("rewards", indices, global_ts) + + def get_infos( + self, indices: Union[int, List[int]] = -1, global_ts: bool = True + ) -> MultiAgentDict: + """Gets infos for all agents that stepped in the last timesteps. + + Note that infos are only returned for agents that stepped + during the given index range. + + Args: + indices: Either a single index or a list of indices. The indices + can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). + This defines the time indices for which the infos + should be returned. + global_ts: Boolean that defines, if the indices should be considered + environment (`True`) or agent (`False`) steps. + + Returns: A dictionary mapping agent ids to infos (of different + timesteps). Only for agents that have stepped (were ready) at a + timestep, infos are returned (i.e. not all agent ids are + necessarily in the keys). + """ + return self._getattr_by_index("infos", indices, global_ts) + + def get_extra_model_outputs( + self, indices: Union[int, List[int]] = -1, global_ts: bool = True + ) -> MultiAgentDict: + """Gets extra model outputs for all agents that stepped in the last timesteps. + + Note that extra model outputs are only returned for agents that stepped + during the given index range. + + Args: + indices: Either a single index or a list of indices. The indices + can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). + This defines the time indices for which the extra model outputs. + should be returned. + global_ts: Boolean that defines, if the indices should be considered + environment (`True`) or agent (`False`) steps. + + Returns: A dictionary mapping agent ids to extra model outputs (of different + timesteps). Only for agents that have stepped (were ready) at a + timestep, extra model outputs are returned (i.e. not all agent ids are + necessarily in the keys). + """ + return self._getattr_by_index("extra_model_outputs", indices, global_ts) + + def add_initial_observation( + self, + *, + initial_observation: MultiAgentDict, + initial_info: Optional[MultiAgentDict] = None, + initial_state: Optional[MultiAgentDict] = None, + initial_render_image: Optional[np.ndarray] = None, + ) -> None: + """Stores initial observation. + + Args: + initial_observation: Obligatory. A dictionary, mapping agent ids + to initial observations. Note that not all agents must have + an initial observation. + initial_info: Optional. A dictionary, mapping agent ids to initial + infos. Note that not all agents must have an initial info. + initial_state: Optional. 
A dictionary, mapping agent ids to the + initial hidden states of their corresponding model (`RLModule`). + Note, this is only available, if the models are stateful. Note + also that not all agents must have an initial state at `t=0`. + initial_render_image: An RGB uint8 image from rendering the + environment. + """ + assert not self.is_done + # Assume that this episode is completely empty and has not stepped yet. + # Leave self.t (and self.t_started) at 0. + assert self.t == self.t_started == 0 + + # TODO (simon): After clearing with sven for initialization of timesteps + # this might be removed. + if len(self.global_t_to_local_t) == 0: + self.global_t_to_local_t = {agent_id: [] for agent_id in self._agent_ids} + + # Note that we store the render images into the `MultiAgentEpisode` + # instead into each `SingleAgentEpisode`. + if initial_render_image is not None: + self.render_images.append(initial_render_image) + + # Note, all agents will have an initial observation. + for agent_id in initial_observation.keys(): + # Add initial timestep for each agent to the timestep mapping. + self.global_t_to_local_t[agent_id].append(self.t) + + # Add initial observations to the agent's episode. + self.agent_episodes[agent_id].add_initial_observation( + # Note, initial observation has to be provided. + initial_observation=initial_observation[agent_id], + initial_info=None if initial_info is None else initial_info[agent_id], + initial_state=None + if initial_state is None + else initial_state[agent_id], + ) + + def add_timestep( + self, + observation: MultiAgentDict, + action: MultiAgentDict, + reward: MultiAgentDict, + *, + info: Optional[MultiAgentDict] = None, + state: Optional[MultiAgentDict] = None, + is_terminated: Optional[bool] = None, + is_truncated: Optional[bool] = None, + render_image: Optional[np.ndarray] = None, + extra_model_output: Optional[MultiAgentDict] = None, + ) -> None: + """Adds a timestep to the episode. + + Args: + observation: Mandatory. A dictionary, mapping agent ids to their + corresponding observations. Note that not all agents must have stepped + a this timestep. + action: Mandatory. A dictionary, mapping agent ids to their + corresponding actions. Note that not all agents must have stepped + a this timestep. + reward: Mandatory. A dictionary, mapping agent ids to their + corresponding observations. Note that not all agents must have stepped + a this timestep. + info: Optional. A dictionary, mapping agent ids to their + corresponding info. Note that not all agents must have stepped + a this timestep. + state: Optional. A dictionary, mapping agent ids to their + corresponding hidden model state. Note, this is only available for a + stateful model. Also note that not all agents must have stepped a this + timestep. + is_terminated: A boolean indicating, if the environment has been + terminated. + is_truncated: A boolean indicating, if the environment has been + truncated. + render_image: Optional. An RGB uint8 image from rendering the environment. + extra_model_output: Optional. A dictionary, mapping agent ids to their + corresponding specific model outputs (also in a dictionary; e.g. + `vf_preds` for PPO). + """ + # Cannot add data to an already done episode. + assert not self.is_done + + # Environment step. + self.t += 1 + + # TODO (sven, simon): Wilol there still be an `__all__` that is + # terminated or truncated? 
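+        # For now, follow RLlib's multi-agent convention and read the episode-level
+        # done flags from the "__all__" key of the `is_terminated`/`is_truncated`
+        # dicts (callers are assumed to provide that key whenever a dict is passed).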
+ self.is_terminated = ( + False if is_terminated is None else is_terminated["__all__"] + ) + self.is_truncated = False if is_truncated is None else is_truncated["__all__"] + + # Add data to agent episodes. + for agent_id in observation.keys(): + # If the agent stepped we need to keep track in the timestep mapping. + self.global_t_to_local_t[agent_id].append(self.t) + + # Note that we store the render images into the `MultiAgentEpisode` + # instead of storing them into each `SingleAgentEpisode`. + if render_image is not None: + self.render_images.append(render_image) + + self.agent_episodes[agent_id].add_timestep( + observation[agent_id], + action[agent_id], + reward[agent_id], + info=None if info is None else info[agent_id], + state=None if state is None else state[agent_id], + is_terminated=None + if is_terminated is None + else is_terminated[agent_id], + is_truncated=None if is_truncated is None else is_truncated[agent_id], + render_image=None if render_image is None else render_image[agent_id], + extra_model_output=None + if extra_model_output is None + else extra_model_output[agent_id], + ) + + @property + def is_done(self): + """Whether the episode is actually done (terminated or truncated). + + A done episode cannot be continued via `self.add_timestep()` or being + concatenated on its right-side with another episode chunk or being + succeeded via `self.create_successor()`. + + Note that in a multi-agent environment this does not necessarily + correspond to single agents having terminated or being truncated. + + `self.is_terminated` should be `True`, if all agents are terminated and + `self.is_truncated` should be `True`, if all agents are truncated. If + only one or more (but not all!) agents are `terminated/truncated the + `MultiAgentEpisode.is_terminated/is_truncated` should be `False`. This + information about single agent's terminated/truncated states can always + be retrieved from the `SingleAgentEpisode`s inside the 'MultiAgentEpisode` + one. + + If all agents are either terminated or truncated, but in a mixed fashion, + i.e. some are terminated and others are truncated: This is currently + undefined and could potentially be a problem (if a user really implemented + such a multi-agent env that behaves this way). + + Returns: + Boolean defining if an episode has either terminated or truncated. + """ + return self.is_terminated or self.is_truncated + + def create_successor(self) -> "MultiAgentEpisode": + """Restarts an ongoing episode from its last observation. + + Note, this method is used so far specifically for the case of + `batch_mode="truncated_episodes"` to ensure that episodes are + immutable inside the `EnvRunner` when truncated and passed over + to postprocessing. + + The newly created `MultiAgentEpisode` contains the same id, and + starts at the timestep where it's predecessor stopped in the last + rollout. Last observations, infos, rewards, etc. are carried over + from the predecessor. This also helps to not carry stale data that + had been collected in the last rollout when rolling out the new + policy in the next iteration (rollout). + + Returns: A MultiAgentEpisode starting at the timepoint where + its predecessor stopped. + """ + assert not self.is_done + + # Get the last multi-agent observation and info. + observations = self.get_observations() + infos = self.get_infos() + # It is more safe to use here a list of episode ids instead of + # calling `create_successor()` as we need as the single source + # of truth always the `global_t_to_local_t` timestep mapping. 
+ return MultiAgentEpisode( + id=self.id_, + agent_episode_ids={ + agent_id: agent_eps.id_ for agent_id, agent_eps in self.agent_episodes + }, + observations=observations, + infos=infos, + t_started=self.t, + ) + + def get_state(self) -> Dict[str, Any]: + """Returns the state of a multi-agent episode. + + Note that from an episode's state the episode itself can + be recreated. + + Returns: A dicitonary containing pickable data fro a + `MultiAgentEpisode`. + """ + return list( + { + "id_": self.id_, + "agent_ids": self._agent_ids, + "global_t_to_local_t": self.global_t_to_local_t, + "agent_episodes": list( + { + agent_id: agent_eps.get_state() + for agent_id, agent_eps in self.agent_episodes.items() + }.items() + ), + "t_started": self.t_started, + "t": self.t, + "is_terminated": self.is_terminated, + "is_truncated": self.is_truncated, + }.items() + ) + + @staticmethod + def from_state(state) -> None: + """Creates a multi-agent episode from a state dictionary. + + See `MultiAgentEpisode.get_state()` for creating a state for + a `MultiAgentEpisode` pickable state. For recreating a + `MultiAgentEpisode` from a state, this state has to be complete, + i.e. all data must have been stored in the state. + """ + eps = MultiAgentEpisode(id=state[0][1]) + eps._agent_ids = state[1][1] + eps.global_t_to_local_t = state[2][1] + eps.agent_episodes = { + agent_id: SingleAgentEpisode.from_state(agent_state) + for agent_id, agent_state in state[3][1] + } + eps.t_started = state[3][1] + eps.t = state[4][1] + eps.is_terminated = state[5][1] + eps.is_trcunated = state[6][1] + return eps + + def to_sample_batch(self) -> MultiAgentBatch: + """Converts a `MultiAgentEpisode` into a `MultiAgentBatch`. + + Each `SingleAgentEpisode` instances in + `MultiAgentEpisode.agent_epiosdes` will be converted into + a `SampleBatch` and the environment timestep will be passed + towards the `MultiAgentBatch`'s `count`. + + Returns: A `MultiAgentBatch` instance. + """ + # TODO (simon): Check, if timesteps should be converted into global + # timesteps instead of agent steps. + return MultiAgentBatch( + policy_batches={ + agent_id: agent_eps.to_sample_batch() + for agent_id, agent_eps in self.agent_episodes.items() + }, + env_steps=self.t, + ) + + def get_return(self) -> float: + """Get the all-agent return. + + Returns: A float. The aggregate return from all agents. + """ + return sum( + [agent_eps.get_return() for agent_eps in self.agent_episodes.values()] + ) + + def _generate_ts_mapping( + self, observations: List[MultiAgentDict] + ) -> MultiAgentDict: + """Generates a timestep mapping to local agent timesteps. + + This helps us to keep track of which agent stepped at + which global (environment) timestep. + Note that the local (agent) timestep is given by the index + of the list for each agent. + + Args: + observations: A list of observations.Each observations maps agent + ids to their corresponding observation. + + Returns: A dictionary mapping agents to time index lists. The latter + contain the global (environment) timesteps at which the agent + stepped (was ready). + """ + # Only if agent ids have been provided we can build the timestep mapping. + if len(self._agent_ids) > 0: + global_t_to_local_t = {agent: _IndexMapping() for agent in self._agent_ids} + + # We need the observations to create the timestep mapping. + if len(observations) > 0: + for t, agent_map in enumerate(observations): + for agent_id in agent_map: + # If agent stepped add the timestep to the timestep mapping. 
+ global_t_to_local_t[agent_id].append(t) + # Otherwise, set to an empoty dict (when creating an empty episode). + else: + global_t_to_local_t = {} + # Otherwise, set to an empoty dict (when creating an empty episode). + else: + # TODO (sven, simon): Shall we return an empty dict or an agent dict with + # empty lists? + global_t_to_local_t = {} + # Return the index mapping. + return global_t_to_local_t + + # TODO (simon): Add infos. + def _generate_single_agent_episode( + self, + agent_id: str, + agent_episode_ids: Optional[Dict[str, str]] = None, + observations: Optional[List[MultiAgentDict]] = None, + actions: Optional[List[MultiAgentDict]] = None, + rewards: Optional[List[MultiAgentDict]] = None, + infos: Optional[List[MultiAgentDict]] = None, + states: Optional[MultiAgentDict] = None, + extra_model_outputs: Optional[MultiAgentDict] = None, + ) -> SingleAgentEpisode: + """Generates a `SingleAgentEpisode` from multi-agent data. + + Note, if no data is provided an empty `SingleAgentEpiosde` + will be returned that starts at `SIngleAgentEpisode.t_started=0`. + + Args: + agent_id: String, idnetifying the agent for which the data should + be extracted. + agent_episode_ids: Optional. A dictionary mapping agents to + corresponding episode ids. If `None` the `SingleAgentEpisode` + creates a hexadecimal code. + observations: Optional. A list of dictionaries, each mapping + from agent ids to observations. When data is provided + it should be complete, i.e. observations, actions, rewards, + etc. should be provided. + actions: Optional. A list of dictionaries, each mapping + from agent ids to actions. When data is provided + it should be complete, i.e. observations, actions, rewards, + etc. should be provided. + rewards: Optional. A list of dictionaries, each mapping + from agent ids to rewards. When data is provided + it should be complete, i.e. observations, actions, rewards, + etc. should be provided. + infos: Optional. A list of dictionaries, each mapping + from agent ids to infos. When data is provided + it should be complete, i.e. observations, actions, rewards, + etc. should be provided. + states: Optional. A dicitionary mapping each agent to it's + module's hidden model state (if the model is stateful). + extra_model_outputs: Optional. A list of agent mappings for every + timestep. Each of these dictionaries maps an agent to its + corresponding `extra_model_outputs`, which a re specific model + outputs needed by the algorithm used (e.g. `vf_preds` and + `action_logp` for PPO). f data is provided it should be complete + (i.e. observations, actions, rewards, is_terminated, is_truncated, + and all necessary `extra_model_outputs`). + + Returns: An instance of `SingleAgentEpisode` containing the agent's + extracted episode data. + """ + + # If an episode id for an agent episode was provided assign it. + episode_id = None if agent_episode_ids is None else agent_episode_ids[agent_id] + # We need the timestep mapping to create single agent's episode. + if len(self.global_t_to_local_t) > 0: + # Set to None if not provided. + agent_observations = ( + None + if observations is None + else self._get_single_agent_data(agent_id, observations) + ) + + # Note, the timestep mapping is deduced from observations and starts one + # timestep earlier. Therefore all other data is missing the last index. 
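+            # Purely illustrative example: if an agent observed at env steps
+            # [0, 2, 5], its actions/rewards are taken from the multi-agent lists
+            # at indices [1, 4], hence `start_index=1` combined with `shift=-1`.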
+ agent_actions = ( + None + if actions is None + else self._get_single_agent_data( + agent_id, actions, start_index=1, shift=-1 + ) + ) + agent_rewards = ( + None + if rewards is None + else self._get_single_agent_data( + agent_id, rewards, start_index=1, shift=-1 + ) + ) + # Like observations, infos start at timestep `t=0`, so we do not need to + # shift. + agent_infos = ( + None + if infos is None + else self._get_single_agent_data(agent_id, infos, start_index=1) + ) + agent_states = ( + None + if states is None + else self._get_single_agent_data( + agent_id, states, start_index=1, shift=-1 + ) + ) + agent_extra_model_outputs = ( + None + if extra_model_outputs is None + else self._get_single_agent_data( + agent_id, extra_model_outputs, start_index=1, shift=-1 + ) + ) + + return SingleAgentEpisode( + id_=episode_id, + observations=agent_observations, + actions=agent_actions, + rewards=agent_rewards, + infos=agent_infos, + states=agent_states, + extra_model_outputs=agent_extra_model_outputs, + ) + # Otherwise return empty ' SingleAgentEpisosde'. + else: + return SingleAgentEpisode(id_=episode_id) + + def _getattr_by_index( + self, + attr: str = "observations", + indices: Union[int, List[int]] = -1, + global_ts: bool = True, + ) -> MultiAgentDict: + # First for global_ts = True: + if global_ts: + # Check, if the indices are iterable. + if isinstance(indices, list): + indices = [self.t + (idx if idx < 0 else idx) for idx in indices] + else: + indices = [self.t + indices] if indices < 0 else [indices] + + return { + agent_id: list( + map( + getattr(agent_eps, attr).__getitem__, + self.global_t_to_local_t[agent_id].find_indices(indices), + ) + ) + for agent_id, agent_eps in self.agent_episodes.items() + # Only include agent data for agents that stepped. + if len(self.global_t_to_local_t[agent_id].find_indices(indices)) > 0 + } + # Otherwise just look for the timesteps in the `SingleAgentEpisode`s + # directly. + else: + # Check, if the indices are iterable. + if not isinstance(indices, list): + indices = [indices] + + return { + agent_id: list(map(getattr(agent_eps, attr).__getitem__, indices)) + for agent_id, agent_eps in self.agent_episodes.items() + # Only include agent data for agents that stepped so far at least once. + # TODO (sven, simon): This will not include initial observations. Should + # we? + if self.agent_episodes[agent_id].t > 0 + } + + def _get_single_agent_data( + self, + agent_id: str, + ma_data: List[MultiAgentDict], + start_index: int = 0, + end_index: Optional[int] = None, + shift: int = 0, + ) -> List[Any]: + """Returns single agent data from multi-agent data. + + Args: + agent_id: A string identifying the agent for which the + data should be extracted. + ma_data: A List of dictionaries, each containing multi-agent + data, i.e. mapping from agent ids to timestep data. + start_index: An integer defining the start point of the + extration window. The default starts at the beginning of the + the `ma_data` list. + end_index: Optional. An integer defining the end point of the + extraction window. If `None`, the extraction window will be + until the end of the `ma_data` list.g + shift: An integer that defines by which amount to shift the + running index for extraction. This is for example needed + when we extract data that started at index 1. + + Returns: A list containing single-agent data from the multi-agent + data provided. + """ + # Return all single agent data along the global timestep. 
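+        # Multi-agent dicts that do not contain `agent_id` are skipped, i.e.
+        # timesteps at which this agent did not step contribute no entry.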
+ return [ + singleton[agent_id] + for singleton in list( + map( + ma_data.__getitem__, + [ + i + shift + for i in self.global_t_to_local_t[agent_id][ + start_index:end_index + ] + ], + ) + ) + if agent_id in singleton.keys() + ] + + def __len__(self): + """Returns the length of an `MultiAgentEpisode`. + + Note that the length of an episode is defined by the difference + between its actual timestep and the starting point. + + Returns: An integer defining the length of the episode or an + error if the episode has not yet started. + """ + assert self.t_started < self.t, ( + "ERROR: Cannot determine length of episode that hasn't started, yet!" + "Call `MultiAgentEpisode.add_initial_observation(initial_observation=)` " + "first (after which `len(MultiAgentEpisode)` will be 0)." + ) + return self.t - self.t_started + + +class _IndexMapping(list): + """Provides lists with a method to find multiple elements. + + This class is used for the timestep mapping which is central to + the multi-agent episode. For each agent the timestep mapping is + implemented with an `IndexMapping`. + + The `IndexMapping.find_indices` method simplifies the search for + multiple environment timesteps at which some agents have stepped. + See for example `MultiAgentEpisode.get_observations()`. + """ + + def find_indices(self, indices_to_find: List[int]): + """Returns global timesteps at which an agent stepped. + + The function returns for a given list of indices the ones + that are stored in the `IndexMapping`. + + Args: + indices_to_find: A list of indices that should be + found in the `IndexMapping`. + + Returns: + A list of indices at which to find the `indices_to_find` + in `self`. This could be empty if none of the given + indices are in `IndexMapping`. + """ + indices = [] + for num in indices_to_find: + if num in self: + indices.append(self.index(num)) + return indices diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index a8696fff2a4b..0146e114bbc2 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -24,7 +24,7 @@ # TODO (sven): This gives a tricky circular import that goes # deep into the library. We have to see, where to dissolve it. - from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode + from ray.rllib.env.single_agent_episode import SingleAgentEpisode _, tf, _ = try_import_tf() torch, nn = try_import_torch() @@ -95,9 +95,11 @@ def __init__(self, config: "AlgorithmConfig", **kwargs): # This should be the default. self._needs_initial_reset: bool = True - self._episodes: List[Optional["Episode"]] = [None for _ in range(self.num_envs)] + self._episodes: List[Optional["SingleAgentEpisode"]] = [ + None for _ in range(self.num_envs) + ] - self._done_episodes_for_metrics: List["Episode"] = [] + self._done_episodes_for_metrics: List["SingleAgentEpisode"] = [] self._ongoing_episodes_for_metrics: Dict[List] = defaultdict(list) self._ts_since_last_metrics: int = 0 self._weights_seq_no: int = 0 @@ -111,7 +113,7 @@ def sample( explore: bool = True, random_actions: bool = False, with_render_data: bool = False, - ) -> List["Episode"]: + ) -> List["SingleAgentEpisode"]: """Runs and returns a sample (n timesteps or m episodes) on the env(s).""" # If not execution details are provided, use the config. 
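The `_sample_timesteps()` and `_sample_episodes()` hunks below exercise the full `SingleAgentEpisode` lifecycle introduced by this patch. A condensed, hedged sketch of that flow (the env/module interaction is elided; `obs`, `action`, `reward`, and `info` are placeholder values rather than real environment outputs):

    from ray.rllib.env.single_agent_episode import SingleAgentEpisode

    obs, action, reward, info = 0.0, 1, 1.0, {}  # placeholders only

    episode = SingleAgentEpisode()
    episode.add_initial_observation(initial_observation=obs, initial_info=info)
    # Repeat once per environment step:
    episode.add_timestep(obs, action, reward, info=info)
    if episode.is_done:
        # Finished episodes can be converted for training ...
        train_batch = episode.to_sample_batch()
    else:
        # ... while ongoing ones are cut here and continued in the next rollout.
        episode = episode.create_successor()
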
@@ -149,16 +151,14 @@ def _sample_timesteps( explore: bool = True, random_actions: bool = False, force_reset: bool = False, - ) -> List["Episode"]: + ) -> List["SingleAgentEpisode"]: """Helper method to sample n timesteps.""" # TODO (sven): This gives a tricky circular import that goes # deep into the library. We have to see, where to dissolve it. - from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( - _Episode as Episode, - ) + from ray.rllib.env.single_agent_episode import SingleAgentEpisode - done_episodes_to_return: List["Episode"] = [] + done_episodes_to_return: List["SingleAgentEpisode"] = [] # Get initial states for all 'batch_size_B` rows in the forward batch, # i.e. for all vector sub_envs. @@ -174,7 +174,7 @@ def _sample_timesteps( if force_reset or self._needs_initial_reset: obs, infos = self.env.reset() - self._episodes = [Episode() for _ in range(self.num_envs)] + self._episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] states = initial_states # Set initial obs and states in the episodes. @@ -284,7 +284,7 @@ def _sample_timesteps( done_episodes_to_return.append(self._episodes[i]) # Create a new episode object. - self._episodes[i] = Episode( + self._episodes[i] = SingleAgentEpisode( observations=[obs[i]], infos=[infos[i]], states=s ) else: @@ -320,7 +320,7 @@ def _sample_episodes( explore: bool = True, random_actions: bool = False, with_render_data: bool = False, - ) -> List["Episode"]: + ) -> List["SingleAgentEpisode"]: """Helper method to run n episodes. See docstring of `self.sample()` for more details. @@ -328,14 +328,12 @@ def _sample_episodes( # TODO (sven): This gives a tricky circular import that goes # deep into the library. We have to see, where to dissolve it. - from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( - _Episode as Episode, - ) + from ray.rllib.env.single_agent_episode import SingleAgentEpisode - done_episodes_to_return: List["Episode"] = [] + done_episodes_to_return: List["SingleAgentEpisode"] = [] obs, infos = self.env.reset() - episodes = [Episode() for _ in range(self.num_envs)] + episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] # Multiply states n times according to our vector env batch size (num_envs). states = tree.map_structure( @@ -431,7 +429,7 @@ def _sample_episodes( states[k][i] = (convert_to_numpy(v),) # Create a new episode object. - episodes[i] = Episode( + episodes[i] = SingleAgentEpisode( observations=[obs[i]], infos=[infos[i]], states=s, diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py new file mode 100644 index 000000000000..1d8ea472ee0a --- /dev/null +++ b/rllib/env/single_agent_episode.py @@ -0,0 +1,528 @@ +import numpy as np +import uuid + +from gymnasium.core import ActType, ObsType +from typing import Any, Dict, List, Optional, SupportsFloat + +from ray.rllib.policy.sample_batch import SampleBatch + + +class SingleAgentEpisode: + def __init__( + self, + id_: Optional[str] = None, + *, + observations: List[ObsType] = None, + actions: List[ActType] = None, + rewards: List[SupportsFloat] = None, + infos: List[Dict] = None, + states=None, + t_started: Optional[int] = None, + is_terminated: bool = False, + is_truncated: bool = False, + render_images: Optional[List[np.ndarray]] = None, + extra_model_outputs: Optional[Dict[str, Any]] = None, + ) -> "SingleAgentEpisode": + """Initializes a `SingleAgentEpisode` instance. + + This constructor can be called with or without sampled data. 
Note + that if data is provided the episode will start at timestep + `t_started = len(observations) - 1` (the initial observation is not + counted). If the episode should start at `t_started = 0` (e.g. + because the instance should simply store episode data) this has to + be provided in the `t_started` parameter of the constructor. + + Args: + id_: Optional. Unique identifier for this episode. If no id is + provided the constructor generates a hexadecimal code for the id. + observations: Optional. A list of observations from a rollout. If + data is provided it should be complete (i.e. observations, actions, + rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). The length of the `observations` defines + the default starting value. See the parameter `t_started`. + actions: Optional. A list of actions from a rollout. If data is + provided it should be complete (i.e. observations, actions, + rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). + rewards: Optional. A list of rewards from a rollout. If data is + provided it should be complete (i.e. observations, actions, + rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). + infos: Optional. A list of infos from a rollout. If data is + provided it should be complete (i.e. observations, actions, + rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). + states: Optional. The hidden model states from a rollout. If + data is provided it should be complete (i.e. observations, actions, + rewards, is_terminated, is_truncated, and all necessary + `extra_model_outputs`). States are only avasilable if a stateful + model (`RLModule`) is used. + t_started: Optional. The starting timestep of the episode. The default + is zero. If data is provided, the starting point is from the last + observation onwards (i.e. `t_started = len(observations) - 1). If + this parameter is provided the episode starts at the provided value. + is_terminated: Optional. A boolean indicating, if the episode is already + terminated. Note, this parameter is only needed, if episode data is + provided in the constructor. The default is `False`. + is_truncated: Optional. A boolean indicating, if the episode was + truncated. Note, this parameter is only needed, if episode data is + provided in the constructor. The default is `False`. + render_images: Optional. A list of RGB uint8 images from rendering + the environment. + extra_model_outputs: Optional. A list of dictionaries containing specific + model outputs for the algorithm used (e.g. `vf_preds` and `action_logp` + for PPO) from a rollout. If data is provided it should be complete + (i.e. observations, actions, rewards, is_terminated, is_truncated, + and all necessary `extra_model_outputs`). + """ + self.id_ = id_ or uuid.uuid4().hex + # Observations: t0 (initial obs) to T. + self.observations = [] if observations is None else observations + # Actions: t1 to T. + self.actions = [] if actions is None else actions + # Rewards: t1 to T. + self.rewards = [] if rewards is None else rewards + # Infos: t0 (initial info) to T. + if infos is None: + self.infos = [{} for _ in range(len(self.observations))] + else: + self.infos = infos + # h-states: t0 (in case this episode is a continuation chunk, we need to know + # about the initial h) to T. + self.states = states + # The global last timestep of the episode and the timesteps when this chunk + # started. 
+ # TODO (simon): Check again what are the consequences of this decision for + # the methods of this class. For example the `validate()` method or + # `create_successor`. Write a test. + # Note, the case `t_started > len(observations) - 1` can occur, if a user + # wants to have an episode that is ongoing but does not want to carry the + # stale data from the last rollout in it. + self.t = self.t_started = ( + t_started if t_started is not None else max(len(self.observations) - 1, 0) + ) + if self.t_started < len(self.observations) - 1: + self.t = len(self.observations) - 1 + + # obs[-1] is the final observation in the episode. + self.is_terminated = is_terminated + # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more + # observations in following chunks for this episode). + self.is_truncated = is_truncated + # RGB uint8 images from rendering the env; the images include the corresponding + # rewards. + assert render_images is None or observations is not None + self.render_images = [] if render_images is None else render_images + # Extra model outputs, e.g. `action_dist_input` needed in the batch. + self.extra_model_outputs = ( + {} if extra_model_outputs is None else extra_model_outputs + ) + + def concat_episode(self, episode_chunk: "SingleAgentEpisode"): + """Adds the given `episode_chunk` to the right side of self. + + Args: + episode_chunk: Another `SingleAgentEpisode` to be concatenated. + + Returns: A `SingleAegntEpisode` instance containing the concatenated + from both episodes. + """ + assert episode_chunk.id_ == self.id_ + assert not self.is_done + # Make sure the timesteps match. + assert self.t == episode_chunk.t_started + + episode_chunk.validate() + + # Make sure, end matches other episode chunk's beginning. + assert np.all(episode_chunk.observations[0] == self.observations[-1]) + # Pop out our last observations and infos (as these are identical + # to the first obs and infos in the next episode). + self.observations.pop() + self.infos.pop() + + # Extend ourselves. In case, episode_chunk is already terminated (and numpyfied) + # we need to convert to lists (as we are ourselves still filling up lists). + self.observations.extend(list(episode_chunk.observations)) + self.actions.extend(list(episode_chunk.actions)) + self.rewards.extend(list(episode_chunk.rewards)) + self.infos.extend(list(episode_chunk.infos)) + self.t = episode_chunk.t + self.states = episode_chunk.states + + if episode_chunk.is_terminated: + self.is_terminated = True + elif episode_chunk.is_truncated: + self.is_truncated = True + + for k, v in episode_chunk.extra_model_outputs.items(): + self.extra_model_outputs[k].extend(list(v)) + + # Validate. + self.validate() + + def add_initial_observation( + self, + *, + initial_observation: ObsType, + initial_info: Optional[Dict] = None, + initial_state=None, + initial_render_image: Optional[np.ndarray] = None, + ) -> None: + """Adds the initial data to the episode. + + Args: + initial_observation: Obligatory. The initial observation. + initial_info: Optional. The initial info. + initial_state: Optional. The initial hidden state of a + model (`RLModule`) if the latter is stateful. + initial_render_image: Optional. An RGB uint8 image from rendering + the environment. + """ + assert not self.is_done + assert len(self.observations) == 0 + # Assume that this episode is completely empty and has not stepped yet. + # Leave self.t (and self.t_started) at 0. 
+ assert self.t == self.t_started == 0 + + initial_info = initial_info or {} + + self.observations.append(initial_observation) + self.states = initial_state + self.infos.append(initial_info) + if initial_render_image is not None: + self.render_images.append(initial_render_image) + # TODO (sven): Do we have to call validate here? It is our own function + # that manipulates the object. + self.validate() + + def add_timestep( + self, + observation: ObsType, + action: ActType, + reward: SupportsFloat, + *, + info: Optional[Dict[str, Any]] = None, + state=None, + is_terminated: bool = False, + is_truncated: bool = False, + render_image: Optional[np.ndarray] = None, + extra_model_output: Optional[Dict[str, Any]] = None, + ) -> None: + """Adds a timestep to the episode. + + Args: + observation: The observation received from the + environment. + action: The last action used by the agent. + reward: The last reward received by the agent. + info: The last info recevied from the environment. + state: Optional. The last hidden state of the model (`RLModule` ). + This is only available, if the model is stateful. + is_terminated: A boolean indicating, if the environment has been + terminated. + is_truncated: A boolean indicating, if the environment has been + truncated. + render_image: Optional. An RGB uint8 image from rendering + the environment. + extra_model_output: The last timestep's specific model outputs + (e.g. `vf_preds` for PPO). + """ + # Cannot add data to an already done episode. + assert not self.is_done + + self.observations.append(observation) + self.actions.append(action) + self.rewards.append(reward) + info = info or {} + self.infos.append(info) + self.states = state + self.t += 1 + if render_image is not None: + self.render_images.append(render_image) + if extra_model_output is not None: + for k, v in extra_model_output.items(): + if k not in self.extra_model_outputs: + self.extra_model_outputs[k] = [v] + else: + self.extra_model_outputs[k].append(v) + self.is_terminated = is_terminated + self.is_truncated = is_truncated + self.validate() + + def validate(self) -> None: + """Validates the episode. + + This function ensures that the data stored to a `SingleAgentEpisode` is + in order (e.g. that the correct number of observations, actions, rewards + are there). + """ + # Make sure we always have one more obs stored than rewards (and actions) + # due to the reset and last-obs logic of an MDP. + assert ( + len(self.observations) + == len(self.infos) + == len(self.rewards) + 1 + == len(self.actions) + 1 + ) + # TODO (sven): This is unclear to me. It makes sense + # to start at a point after the length of experiences + # provided at initialization, but when we test then here + # it will imo always error out. + # Example: we initialize the class by providing 101 observations, + # 100 actions and rewards. + # self.t = self.t_started = len(observations) - 1. Then + # we add a single timestep. self.t += 1 and + # self.t - self.t_started is 1, but len(rewards) is 100. + assert len(self.rewards) == (self.t - self.t_started) + + if len(self.extra_model_outputs) > 0: + for k, v in self.extra_model_outputs.items(): + assert len(v) == len(self.observations) - 1 + + # Convert all lists to numpy arrays, if we are terminated. + if self.is_done: + self.convert_lists_to_numpy() + + @property + def is_done(self) -> bool: + """Whether the episode is actually done (terminated or truncated). 
+ + A done episode cannot be continued via `self.add_timestep()` or being + concatenated on its right-side with another episode chunk or being + succeeded via `self.create_successor()`. + """ + return self.is_terminated or self.is_truncated + + def convert_lists_to_numpy(self) -> None: + """Converts list attributes to numpy arrays. + + When an episode is terminated or truncated (`self.is_done`) the data + will be not anymore touched and instead converted to numpy for later + use in postprocessing. This function converts all the data stored + into numpy arrays. + """ + + self.observations = np.array(self.observations) + self.actions = np.array(self.actions) + self.rewards = np.array(self.rewards) + self.infos = np.array(self.infos) + self.render_images = np.array(self.render_images, dtype=np.uint8) + for k, v in self.extra_model_outputs.items(): + self.extra_model_outputs[k] = np.array(v) + + def create_successor(self) -> "SingleAgentEpisode": + """Returns a successor episode chunk (of len=0) continuing with this one. + + The successor will have the same ID and state as self and its only observation + will be the last observation in self. Its length will therefore be 0 (no + steps taken yet). + + This method is useful if you would like to discontinue building an episode + chunk (b/c you have to return it from somewhere), but would like to have a new + episode (chunk) instance to continue building the actual env episode at a later + time. + + Returns: + The successor Episode chunk of this one with the same ID and state and the + only observation being the last observation in self. + """ + assert not self.is_done + + return SingleAgentEpisode( + # Same ID. + id_=self.id_, + # First (and only) observation of successor is this episode's last obs. + observations=[self.observations[-1]], + # First (and only) info of successor is this episode's last info. + infos=[self.infos[-1]], + # Same state. + states=self.states, + # Continue with self's current timestep. + t_started=self.t, + ) + + def to_sample_batch(self) -> SampleBatch: + """Converts a `SingleAgentEpisode` into a `SampleBatch`. + + Note that `RLlib` is relying in training on the `SampleBatch` class and + therefore episodes have to be converted to this format before training can + start. + + Returns: + An `ray.rLlib.policy.sample_batch.SampleBatch` instance containing this + episode's data. + """ + return SampleBatch( + { + SampleBatch.EPS_ID: np.array([self.id_] * len(self)), + SampleBatch.OBS: self.observations[:-1], + SampleBatch.NEXT_OBS: self.observations[1:], + SampleBatch.ACTIONS: self.actions, + SampleBatch.REWARDS: self.rewards, + SampleBatch.T: list(range(self.t_started, self.t)), + SampleBatch.TERMINATEDS: np.array( + [False] * (len(self) - 1) + [self.is_terminated] + ), + SampleBatch.TRUNCATEDS: np.array( + [False] * (len(self) - 1) + [self.is_truncated] + ), + # Return the infos after stepping the environment. + SampleBatch.INFOS: self.infos[1:], + **self.extra_model_outputs, + } + ) + + @staticmethod + def from_sample_batch(batch: SampleBatch) -> "SingleAgentEpisode": + """Converts a `SampleBatch` instance into a `SingleAegntEpisode`. + + The `ray.rllib.policy.sample_batch.SampleBatch` class is used in `RLlib` + for training an agent's modules (`RLModule`), converting from or to + `SampleBatch` can be performed by this function and its counterpart + `to_sample_batch()`. + + Args: + batch: A `SampleBatch` instance. It should contain only a single episode. 
+
+    @staticmethod
+    def from_sample_batch(batch: SampleBatch) -> "SingleAgentEpisode":
+        """Converts a `SampleBatch` instance into a `SingleAgentEpisode`.
+
+        The `ray.rllib.policy.sample_batch.SampleBatch` class is used in `RLlib`
+        for training an agent's modules (`RLModule`); converting from or to
+        `SampleBatch` can be performed by this function and its counterpart
+        `to_sample_batch()`.
+
+        Args:
+            batch: A `SampleBatch` instance. It should contain only a single episode.
+
+        Returns:
+            A `SingleAgentEpisode` instance containing the data from `batch`.
+        """
+        is_done = (
+            batch[SampleBatch.TERMINATEDS][-1] or batch[SampleBatch.TRUNCATEDS][-1]
+        )
+        observations = np.concatenate(
+            [batch[SampleBatch.OBS], batch[SampleBatch.NEXT_OBS][None, -1]]
+        )
+        actions = batch[SampleBatch.ACTIONS]
+        rewards = batch[SampleBatch.REWARDS]
+        # These are the infos after stepping the environment, i.e. without the
+        # initial info.
+        infos = batch[SampleBatch.INFOS]
+        # Concatenate an initial empty info.
+        infos = np.concatenate([np.array([{}]), infos])
+
+        # TODO (simon): This is very ugly, but right now
+        # we can only do it according to the exclusion principle.
+        extra_model_output_keys = []
+        for k in batch.keys():
+            if k not in [
+                SampleBatch.EPS_ID,
+                SampleBatch.AGENT_INDEX,
+                SampleBatch.ENV_ID,
+                SampleBatch.AGENT_INDEX,
+                SampleBatch.T,
+                SampleBatch.SEQ_LENS,
+                SampleBatch.OBS,
+                SampleBatch.INFOS,
+                SampleBatch.NEXT_OBS,
+                SampleBatch.ACTIONS,
+                SampleBatch.PREV_ACTIONS,
+                SampleBatch.REWARDS,
+                SampleBatch.PREV_REWARDS,
+                SampleBatch.TERMINATEDS,
+                SampleBatch.TRUNCATEDS,
+                SampleBatch.UNROLL_ID,
+                SampleBatch.DONES,
+                SampleBatch.CUR_OBS,
+            ]:
+                extra_model_output_keys.append(k)
+
+        return SingleAgentEpisode(
+            id_=batch[SampleBatch.EPS_ID][0],
+            observations=observations if is_done else observations.tolist(),
+            actions=actions if is_done else actions.tolist(),
+            rewards=rewards if is_done else rewards.tolist(),
+            t_started=batch[SampleBatch.T][0],
+            is_terminated=batch[SampleBatch.TERMINATEDS][-1],
+            is_truncated=batch[SampleBatch.TRUNCATEDS][-1],
+            infos=infos if is_done else infos.tolist(),
+            extra_model_outputs={
+                k: (batch[k] if is_done else batch[k].tolist())
+                for k in extra_model_output_keys
+            },
+        )
+
+    def get_return(self) -> float:
+        """Calculates an episode's return.
+
+        The return is computed by a simple sum, neglecting the discount factor.
+        This is used predominantly for metrics.
+
+        Returns:
+            The sum of rewards collected during this episode.
+        """
+        return sum(self.rewards)
+
+    def get_state(self) -> Dict[str, Any]:
+        """Returns the picklable state of an episode.
+
+        The data in the episode is stored as a list of key/value pairs. Note that
+        episodes can also be generated from states (see `self.from_state()`).
+
+        Returns:
+            A list of key/value pairs containing all the data from the episode.
+        """
+        return list(
+            {
+                "id_": self.id_,
+                "observations": self.observations,
+                "actions": self.actions,
+                "rewards": self.rewards,
+                "infos": self.infos,
+                "states": self.states,
+                "t_started": self.t_started,
+                "t": self.t,
+                "is_terminated": self.is_terminated,
+                "is_truncated": self.is_truncated,
+                **self.extra_model_outputs,
+            }.items()
+        )
+
+    @staticmethod
+    def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode":
+        """Generates a `SingleAgentEpisode` from a picklable state.
+
+        The data in the state has to be complete. This is always the case when the state
+        was created by a `SingleAgentEpisode` itself via `self.get_state()`.
+
+        Args:
+            state: The episode state, as returned by `self.get_state()`.
+
+        Returns:
+            A `SingleAgentEpisode` instance holding all the data provided by `state`.
+ """ + eps = SingleAgentEpisode(id_=state[0][1]) + eps.observations = state[1][1] + eps.actions = state[2][1] + eps.rewards = state[3][1] + eps.infos = state[4][1] + eps.states = state[5][1] + eps.t_started = state[6][1] + eps.t = state[7][1] + eps.is_terminated = state[8][1] + eps.is_truncated = state[9][1] + eps.extra_model_outputs = {k: v for k, v in state[10:]} + # Validate the episode to ensure complete data. + eps.validate() + return eps + + def __len__(self) -> int: + """Returning the length of an episode. + + The length of an episode is defined by the length of its data. This is the + number of timesteps an agent has stepped through an environment so far. + The length is undefined in case of a just started episode. + + Returns: + An integer, defining the length of an episode. + + Raises: + AssertionError: If episode has never been stepped so far. + """ + assert len(self.observations) > 0, ( + "ERROR: Cannot determine length of episode that hasn't started yet! Call " + "`SingleAgentEpisode.add_initial_observation(initial_observation=...)` " + "first (after which `len(SingleAgentEpisode)` will be 0)." + ) + return len(self.observations) - 1 diff --git a/rllib/env/testing/single_agent_gym_env_runner.py b/rllib/env/testing/single_agent_gym_env_runner.py index fefe24b84204..2bf885b94c54 100644 --- a/rllib/env/testing/single_agent_gym_env_runner.py +++ b/rllib/env/testing/single_agent_gym_env_runner.py @@ -4,8 +4,8 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.annotations import override -from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode as Episode class SingleAgentGymEnvRunner(EnvRunner): @@ -42,7 +42,7 @@ def sample( num_episodes: Optional[int] = None, force_reset: bool = False, **kwargs, - ) -> Tuple[List[Episode], List[Episode]]: + ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: """Returns a tuple (list of completed episodes, list of ongoing episodes). Args: @@ -61,9 +61,9 @@ def sample( **kwargs: Forward compatibility kwargs. Returns: - A tuple consisting of: A list of Episode instances that are already - done (either terminated or truncated, hence their `is_done` property is - True), a list of Episode instances that are still ongoing + A tuple consisting of: A list of SingleAgentEpisode instances that are + already done (either terminated or truncated, hence their `is_done` property + is True), a list of SingleAgentEpisode instances that are still ongoing (their `is_done` property is False). """ assert not (num_timesteps is not None and num_episodes is not None) @@ -89,7 +89,7 @@ def _sample_timesteps( self, num_timesteps: int, force_reset: bool = False, - ) -> Tuple[List[Episode], List[Episode]]: + ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: """Runs n timesteps on the environment(s) and returns experiences. Timesteps are counted in total (across all vectorized sub-environments). For @@ -104,7 +104,7 @@ def _sample_timesteps( # Start new episodes. # Set initial observations of the new episodes. self._episodes = [ - Episode(observations=[o]) for o in self._split_by_env(obs) + SingleAgentEpisode(observations=[o]) for o in self._split_by_env(obs) ] self._needs_initial_reset = False @@ -146,7 +146,7 @@ def _sample_timesteps( # Add this finished episode to the list of completed ones. 
done_episodes_to_return.append(self._episodes[i]) # Start a new episode and set its initial observation to `o`. - self._episodes[i] = Episode(observations=[o]) + self._episodes[i] = SingleAgentEpisode(observations=[o]) # Episode is ongoing -> Add a timestep. else: self._episodes[i].add_timestep(o, a, r) @@ -156,7 +156,7 @@ def _sample_timesteps( ongoing_episodes = self._episodes # Create new chunks (using the same IDs and latest observations). self._episodes = [ - Episode(id_=eps.id_, observations=[eps.observations[-1]]) + SingleAgentEpisode(id_=eps.id_, observations=[eps.observations[-1]]) for eps in self._episodes ] # Return tuple: done episodes, ongoing ones. @@ -175,7 +175,9 @@ def _sample_episodes( done_episodes_to_return = [] obs, _ = self.env.reset() - episodes = [Episode(observations=[o]) for o in self._split_by_env(obs)] + episodes = [ + SingleAgentEpisode(observations=[o]) for o in self._split_by_env(obs) + ] eps = 0 while eps < num_episodes: @@ -216,7 +218,7 @@ def _sample_episodes( break # Start a new episode and set its initial observation to `o`. - episodes[i] = Episode(observations=[o]) + episodes[i] = SingleAgentEpisode(observations=[o]) else: episodes[i].add_timestep(o, a, r) diff --git a/rllib/env/tests/test_single_agent_episode.py b/rllib/env/tests/test_single_agent_episode.py new file mode 100644 index 000000000000..dc549e4f5589 --- /dev/null +++ b/rllib/env/tests/test_single_agent_episode.py @@ -0,0 +1,453 @@ +import gymnasium as gym +import numpy as np +import unittest + +from gymnasium.core import ActType, ObsType +from typing import Any, Dict, Optional, SupportsFloat, Tuple + +import ray +from ray.rllib.env.single_agent_episode import SingleAgentEpisode + +# TODO (simon): Add to the tests `info` and `extra_model_outputs` +# as soon as #39732 is merged. + + +class TestEnv(gym.Env): + def __init__(self): + self.observation_space = gym.spaces.Discrete(201) + self.action_space = gym.spaces.Discrete(200) + self.t = 0 + + def reset( + self, *, seed: Optional[int] = None, options=Optional[Dict[str, Any]] + ) -> Tuple[ObsType, Dict[str, Any]]: + self.t = 0 + return 0, {} + + def step( + self, action: ActType + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + self.t += 1 + if self.t == 200: + is_terminated = True + else: + is_terminated = False + + return self.t, self.t, is_terminated, False, {} + + +class TestSingelAgentEpisode(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_init(self): + """Tests initialization of `SingleAgentEpisode`. + + Three cases are tested: + 1. Empty episode with default starting timestep. + 2. Empty episode starting at `t_started=10`. This is only interesting + for ongoing episodes, where we do not want to carry on the stale + entries from the last rollout. + 3. Initialization with pre-collected data. + """ + # Create empty episode. + episode = SingleAgentEpisode() + # Empty episode should have a start point and count of zero. + self.assertTrue(episode.t_started == episode.t == 0) + + # Create an episode with a specific starting point. + episode = SingleAgentEpisode(t_started=10) + self.assertTrue(episode.t == episode.t_started == 10) + + # Sample 100 values and initialize episode with observations and infos. + env = gym.make("CartPole-v1") + # Initialize containers. 
+ observations = [] + rewards = [] + actions = [] + infos = [] + extra_model_outputs = [] + states = np.random.random(10) + + # Initialize observation and info. + init_obs, init_info = env.reset() + observations.append(init_obs) + infos.append(init_info) + # Run 100 samples. + for _ in range(100): + action = env.action_space.sample() + obs, reward, is_terminated, is_truncated, info = env.step(action) + observations.append(obs) + actions.append(action) + rewards.append(reward) + infos.append(info) + extra_model_outputs.append({"extra_1": np.random.random()}) + + # Build the episode. + episode = SingleAgentEpisode( + observations=observations, + actions=actions, + rewards=rewards, + infos=infos, + states=states, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_outputs=extra_model_outputs, + ) + # The starting point and count should now be at `len(observations) - 1`. + self.assertTrue(episode.t == episode.t_started == (len(observations) - 1)) + + def test_add_initial_observation(self): + """Tests adding initial observations and infos. + + This test ensures that when initial observation and info are provided + the length of the lists are correct and the timestep is still at zero, + as the agent has not stepped, yet. + """ + # Create empty episode. + episode = SingleAgentEpisode() + # Create environment. + env = gym.make("CartPole-v1") + + # Add initial observations. + obs, info = env.reset() + episode.add_initial_observation(initial_observation=obs, initial_info=info) + + # Assert that the observations are added to their list. + self.assertTrue(len(episode.observations) == 1) + # Assert that the infos are added to their list. + self.assertTrue(len(episode.infos) == 1) + # Assert that the timesteps are still at zero as we have not stepped, yet. + self.assertTrue(episode.t == episode.t_started == 0) + + def test_add_timestep(self): + """Tests if adding timestep data to a `SingleAgentEpisode` works. + + Adding timestep data is the central part of collecting episode + dara. Here it is tested if adding to the internal data lists + works as intended and the timestep is increased during each step. + """ + # Create an empty episode and add initial observations. + episode = SingleAgentEpisode() + env = gym.make("CartPole-v1") + # Set the random seed (otherwise the episode will terminate at + # different points in each test run). + obs, info = env.reset(seed=0) + episode.add_initial_observation(initial_observation=obs, initial_info=info) + + # Sample 100 timesteps and add them to the episode. + for i in range(100): + action = env.action_space.sample() + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + if is_terminated or is_truncated: + break + + # Assert that the episode timestep is at 100. + self.assertTrue(episode.t == len(episode.observations) - 1 == i + 1) + # Assert that `t_started` stayed at zero. + self.assertTrue(episode.t_started == 0) + # Assert that all lists have the proper lengths. + self.assertTrue( + len(episode.actions) + == len(episode.rewards) + == len(episode.observations) - 1 + == len(episode.infos) - 1 + == i + 1 + ) + # Assert that the flags are set correctly. 
+ self.assertTrue(episode.is_terminated == is_terminated) + self.assertTrue(episode.is_truncated == is_truncated) + self.assertTrue(episode.is_done == is_terminated or is_truncated) + + def test_create_successor(self): + """Tests creation of a scucessor of a `SingleAgentEpisode`. + + This test makes sure that when creating a successor the successor's + data is coherent with the episode that should be succeeded. + Observation and info are available before each timestep; therefore + these data is carried over to the successor. + """ + + # Create an empty episode. + episode_1 = SingleAgentEpisode() + # Create an environment. + env = TestEnv() + # Add initial observation. + init_obs, init_info = env.reset() + episode_1.add_initial_observation( + initial_observation=init_obs, initial_info=init_info + ) + # Sample 100 steps. + for i in range(100): + action = i + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode_1.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + + # Assert that the episode has indeed 100 timesteps. + self.assertTrue(episode_1.t == 100) + + # Create a successor. + episode_2 = episode_1.create_successor() + # Assert that it has the same id. + self.assertTrue(episode_1.id_ == episode_2.id_) + # Assert that the timestep starts at the end of the last episode. + self.assertTrue(episode_1.t == episode_2.t == episode_2.t_started) + # Assert that the last observation of `episode_1` is the first of + # `episode_2`. + self.assertTrue(episode_1.observations[-1] == episode_2.observations[0]) + # Assert that the last info of `episode_1` is the first of episode_2`. + self.assertTrue(episode_1.infos[-1] == episode_2.infos[0]) + + # Test immutability. + action = 100 + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode_2.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + # Assert that this does not change also the predecessor's data. + self.assertFalse(len(episode_1.observations) == len(episode_2.observations)) + + def test_concat_episode(self): + """Tests if concatenation of two `SingleAgentEpisode`s works. + + This test ensures that concatenation of two episodes work. Note that + concatenation should only work for two chunks of the same episode, i.e. + they have the same `id_` and one should be the successor of the other. + It is also tested that concatenation fails, if timesteps do not match or + the episode to which we want to concatenate is already terminated. + """ + # Create two episodes and fill them with 100 timesteps each. + env = TestEnv() + init_obs, init_info = env.reset() + episode_1 = SingleAgentEpisode() + episode_1.add_initial_observation( + initial_observation=init_obs, initial_info=init_info + ) + # Sample 100 timesteps. + for i in range(100): + action = i + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode_1.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + + # Create a successor. + episode_2 = episode_1.create_successor() + + # Now, sample 100 more timesteps. 
+ for i in range(100, 200): + action = i + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode_2.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + + # Assert that the second episode's `t_started` is at the first episode's + # `t`. + self.assertTrue(episode_1.t == episode_2.t_started) + # Assert that the second episode's `t` is at 200. + self.assertTrue(episode_2.t == 200) + + # Manipulate the id of the second episode and make sure an error is + # thrown during concatenation. + episode_2.id_ = "wrong" + with self.assertRaises(AssertionError): + episode_1.concat_episode(episode_2) + # Reset the id. + episode_2.id_ = episode_1.id_ + # Assert that when timesteps do not match an error is thrown. + episode_2.t += 1 + with self.assertRaises(AssertionError): + episode_1.concat_episode(episode_2) + # Reset the timestep. + episode_2.t -= 1 + # Assert that when the first episode is already done no concatenation can take + # place. + episode_1.is_terminated = True + with self.assertRaises(AssertionError): + episode_1.concat_episode(episode_2) + # Reset `is_terminated`. + episode_1.is_terminated = False + + # Concate the episodes. + + episode_1.concat_episode(episode_2) + # Assert that the concatenated episode start at `t_started=0` + # and has 200 sampled steps, i.e. `t=200`. + self.assertTrue(episode_1.t_started == 0) + self.assertTrue(episode_1.t == 200) + # Assert that all lists have appropriate length. + self.assertTrue( + len(episode_1.actions) + == len(episode_1.rewards) + == len(episode_1.observations) - 1 + == len(episode_1.infos) - 1 + == 200 + ) + # Assert that specific observations in the two episodes match. + self.assertEqual(episode_2.observations[5], episode_1.observations[105]) + # Assert that they are not the same object. + # TODO (sven): Do we really need a deepcopy here? + # self.assertNotEqual(id(episode_2.observations[5]), + # id(episode_1.observations[105])) + + def test_get_and_from_state(self): + """Tests, if a `SingleAgentEpisode` can be reconstructed form state. + + This test constructs an episode, stores it to its dictionary state and + recreates a new episode form this state. Thereby it ensures that all + atttributes are indeed identical to the primer episode and the data is + complete. + """ + # Create an empty episode. + episode = SingleAgentEpisode() + # Create an environment. + env = TestEnv() + # Add initial observation. + init_obs, init_info = env.reset() + episode.add_initial_observation( + initial_observation=init_obs, initial_info=init_info + ) + # Sample 100 steps. + for i in range(100): + action = i + obs, reward, is_terminated, is_truncated, info = env.step(action) + episode.add_timestep( + observation=obs, + action=action, + reward=reward, + info=info, + is_terminated=is_terminated, + is_truncated=is_truncated, + extra_model_output={"extra": np.random.random(1)}, + ) + + # Get the state and reproduce it from state. + state = episode.get_state() + episode_reproduced = SingleAgentEpisode.from_state(state) + + # Assert that the data is complete. 
+        self.assertEqual(episode.id_, episode_reproduced.id_)
+        self.assertEqual(episode.t, episode_reproduced.t)
+        self.assertEqual(episode.t_started, episode_reproduced.t_started)
+        self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated)
+        self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated)
+        self.assertListEqual(episode.observations, episode_reproduced.observations)
+        self.assertListEqual(episode.actions, episode_reproduced.actions)
+        self.assertListEqual(episode.rewards, episode_reproduced.rewards)
+        self.assertListEqual(episode.infos, episode_reproduced.infos)
+        self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated)
+        self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated)
+        self.assertEqual(episode.states, episode_reproduced.states)
+        self.assertListEqual(episode.render_images, episode_reproduced.render_images)
+        self.assertDictEqual(
+            episode.extra_model_outputs, episode_reproduced.extra_model_outputs
+        )
+
+        # Assert that reconstruction breaks if the data is not complete.
+        state[1][1].pop()
+        with self.assertRaises(AssertionError):
+            episode_reproduced = SingleAgentEpisode.from_state(state)
+
+    def test_to_and_from_sample_batch(self):
+        """Tests if a `SingleAgentEpisode` can be reconstructed from a `SampleBatch`.
+
+        This test converts an episode to a `SampleBatch` and then reconstructs
+        the episode from this sample batch. It is then tested whether all data
+        is complete.
+        Note that `extra_model_outputs` are defined by the user and, as their format
+        in the episode from which a `SampleBatch` was created is unknown, this
+        reconstruction only works if the user takes care of it (as a
+        counterexample, just remove the index [0] from the `extra_model_output`).
+        """
+        # Create an empty episode.
+        episode = SingleAgentEpisode()
+        # Create an environment.
+        env = TestEnv()
+        # Add initial observation.
+        init_obs, init_info = env.reset()
+        episode.add_initial_observation(
+            initial_observation=init_obs, initial_info=init_info
+        )
+        # Sample 100 steps.
+        for i in range(100):
+            action = i
+            obs, reward, is_terminated, is_truncated, info = env.step(action)
+            episode.add_timestep(
+                observation=obs,
+                action=action,
+                reward=reward,
+                info=info,
+                is_terminated=is_terminated,
+                is_truncated=is_truncated,
+                extra_model_output={"extra": np.random.random(1)[0]},
+            )
+
+        # Create `SampleBatch`.
+        batch = episode.to_sample_batch()
+        # Reproduce from the `SampleBatch`.
+        episode_reproduced = SingleAgentEpisode.from_sample_batch(batch)
+        # Assert that the data is complete.
+ self.assertEqual(episode.id_, episode_reproduced.id_) + self.assertEqual(episode.t, episode_reproduced.t) + self.assertEqual(episode.t_started, episode_reproduced.t_started) + self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) + self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) + self.assertListEqual(episode.observations, episode_reproduced.observations) + self.assertListEqual(episode.actions, episode_reproduced.actions) + self.assertListEqual(episode.rewards, episode_reproduced.rewards) + self.assertEqual(episode.infos, episode_reproduced.infos) + self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) + self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) + self.assertEqual(episode.states, episode_reproduced.states) + self.assertListEqual(episode.render_images, episode_reproduced.render_images) + self.assertDictEqual( + episode.extra_model_outputs, episode_reproduced.extra_model_outputs + ) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/evaluation/postprocessing_v2.py b/rllib/evaluation/postprocessing_v2.py index c68a2ba5f295..9fae6c1ce325 100644 --- a/rllib/evaluation/postprocessing_v2.py +++ b/rllib/evaluation/postprocessing_v2.py @@ -12,7 +12,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.replay_buffers.episode_replay_buffer import _Episode +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.torch_utils import convert_to_torch_tensor from ray.rllib.utils.typing import TensorType @@ -28,7 +28,9 @@ class Postprocessing: @DeveloperAPI -def postprocess_episodes_to_sample_batch(episodes: List[_Episode]) -> SampleBatch: +def postprocess_episodes_to_sample_batch( + episodes: List[SingleAgentEpisode], +) -> SampleBatch: """Converts the results from sampling with an `EnvRunner` to one `SampleBatch'. 
Once the `SampleBatch` will be deprecated this function will be @@ -58,7 +60,7 @@ def postprocess_episodes_to_sample_batch(episodes: List[_Episode]) -> SampleBatc @DeveloperAPI def compute_gae_for_episode( - episode: _Episode, + episode: SingleAgentEpisode, config: AlgorithmConfig, module: RLModule, ): @@ -90,7 +92,9 @@ def compute_gae_for_episode( return episode -def compute_bootstrap_value(episode: _Episode, module: RLModule) -> _Episode: +def compute_bootstrap_value( + episode: SingleAgentEpisode, module: RLModule +) -> SingleAgentEpisode: if episode.is_terminated: last_r = 0.0 else: @@ -142,7 +146,7 @@ def compute_bootstrap_value(episode: _Episode, module: RLModule) -> _Episode: def compute_advantages( - episode: _Episode, + episode: SingleAgentEpisode, last_r: float, gamma: float = 0.9, lambda_: float = 1.0, diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py index 484ebc4645e9..4c09df8111bc 100644 --- a/rllib/utils/replay_buffers/episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -1,11 +1,10 @@ from collections import deque import copy from typing import Any, Dict, List, Optional, Union -import uuid import numpy as np -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.annotations import override from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface from ray.rllib.utils.typing import SampleBatchType @@ -101,12 +100,12 @@ def __len__(self) -> int: return self.get_num_timesteps() @override(ReplayBufferInterface) - def add(self, episodes: Union[List["_Episode"], "_Episode"]): - """Converts the incoming SampleBatch into a number of _Episode objects. + def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]): + """Converts the incoming SampleBatch into a number of SingleAgentEpisode objects. Then adds these episodes to the internal deque. """ - if isinstance(episodes, _Episode): + if isinstance(episodes, SingleAgentEpisode): episodes = [episodes] for eps in episodes: @@ -350,7 +349,7 @@ def get_state(self) -> Dict[str, Any]: @override(ReplayBufferInterface) def set_state(self, state) -> None: self.episodes = deque( - [_Episode.from_state(eps_data) for eps_data in state["episodes"]] + [SingleAgentEpisode.from_state(eps_data) for eps_data in state["episodes"]] ) self.episode_id_to_index = dict(state["episode_id_to_index"]) self._num_episodes_evicted = state["_num_episodes_evicted"] @@ -358,318 +357,3 @@ def set_state(self, state) -> None: self._num_timesteps = state["_num_timesteps"] self._num_timesteps_added = state["_num_timesteps_added"] self.sampled_timesteps = state["sampled_timesteps"] - - -# TODO (sven): Make this EpisodeV3 - replacing EpisodeV2 - to reduce all the -# information leakage we currently have in EpisodeV2 (policy_map, worker, etc.. are -# all currently held by EpisodeV2 for no good reason). -class _Episode: - def __init__( - self, - id_: Optional[str] = None, - *, - observations=None, - actions=None, - rewards=None, - infos=None, - states=None, - t: int = 0, - is_terminated: bool = False, - is_truncated: bool = False, - render_images=None, - extra_model_outputs=None, - ): - self.id_ = id_ or uuid.uuid4().hex - # Observations: t0 (initial obs) to T. - self.observations = [] if observations is None else observations - # Actions: t1 to T. - self.actions = [] if actions is None else actions - # Rewards: t1 to T. 
- self.rewards = [] if rewards is None else rewards - # Infos: t0 (initial info) to T. - self.infos = [] if infos is None else infos - # h-states: t0 (in case this episode is a continuation chunk, we need to know - # about the initial h) to T. - self.states = states - # The global last timestep of the episode and the timesteps when this chunk - # started. - self.t = self.t_started = t - # obs[-1] is the final observation in the episode. - self.is_terminated = is_terminated - # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more - # observations in following chunks for this episode). - self.is_truncated = is_truncated - # RGB uint8 images from rendering the env; the images include the corresponding - # rewards. - assert render_images is None or observations is not None - self.render_images = [] if render_images is None else render_images - # Extra model outputs, e.g. `action_dist_input` needed in the batch. - self.extra_model_outputs = ( - {} if extra_model_outputs is None else extra_model_outputs - ) - - def concat_episode(self, episode_chunk: "_Episode"): - """Adds the given `episode_chunk` to the right side of self.""" - assert episode_chunk.id_ == self.id_ - assert not self.is_done - # Make sure the timesteps match. - assert self.t == episode_chunk.t_started - - episode_chunk.validate() - - # Make sure, end matches other episode chunk's beginning. - assert np.all(episode_chunk.observations[0] == self.observations[-1]) - # Make sure the timesteps match (our last t should be the same as their first). - assert self.t == episode_chunk.t_started - # Pop out our last observations and infos (as these are identical - # to the first obs and infos in the next episode). - self.observations.pop() - self.infos.pop() - - # Extend ourselves. In case, episode_chunk is already terminated (and numpyfied) - # we need to convert to lists (as we are ourselves still filling up lists). - self.observations.extend(list(episode_chunk.observations)) - self.actions.extend(list(episode_chunk.actions)) - self.rewards.extend(list(episode_chunk.rewards)) - self.infos.extend(list(episode_chunk.infos)) - self.t = episode_chunk.t - self.states = episode_chunk.states - - if episode_chunk.is_terminated: - self.is_terminated = True - elif episode_chunk.is_truncated: - self.is_truncated = True - - for k, v in episode_chunk.extra_model_outputs.items(): - self.extra_model_outputs[k].extend(list(v)) - # Validate. - self.validate() - - def add_initial_observation( - self, - *, - initial_observation, - initial_info=None, - initial_state=None, - initial_render_image=None, - ): - assert not self.is_done - assert len(self.observations) == 0 - # Assume that this episode is completely empty and has not stepped yet. - # Leave self.t (and self.t_started) at 0. - assert self.t == self.t_started == 0 - - initial_info = initial_info or {} - - self.observations.append(initial_observation) - self.infos.append(initial_info) - self.states = initial_state - if initial_render_image is not None: - self.render_images.append(initial_render_image) - self.validate() - - def add_timestep( - self, - observation, - action, - reward, - *, - info=None, - state=None, - is_terminated=False, - is_truncated=False, - render_image=None, - extra_model_output=None, - ): - # Cannot add data to an already done episode. 
- assert not self.is_done - - info = info or {} - - self.observations.append(observation) - self.actions.append(action) - self.rewards.append(reward) - self.infos.append(info) - self.states = state - self.t += 1 - if render_image is not None: - self.render_images.append(render_image) - if extra_model_output is not None: - for k, v in extra_model_output.items(): - if k not in self.extra_model_outputs: - self.extra_model_outputs[k] = [v] - else: - self.extra_model_outputs[k].append(v) - self.is_terminated = is_terminated - self.is_truncated = is_truncated - self.validate() - - def validate(self): - # Make sure we always have one more obs stored than rewards (and actions) - # due to the reset and last-obs logic of an MDP. - assert len(self.observations) == len(self.rewards) + 1 == len(self.actions) + 1 - assert len(self.rewards) == (self.t - self.t_started) - # Convert all lists to numpy arrays, if we are terminated. - if self.is_done: - self.convert_lists_to_numpy() - - @property - def is_done(self): - """Whether the episode is actually done (terminated or truncated). - - A done episode cannot be continued via `self.add_timestep()` or being - concatenated on its right-side with another episode chunk or being - succeeded via `self.create_successor()`. - """ - return self.is_terminated or self.is_truncated - - def convert_lists_to_numpy(self): - """Converts list attributes to numpy arrays.""" - - self.observations = np.array(self.observations) - self.actions = np.array(self.actions) - self.rewards = np.array(self.rewards) - self.infos = np.array(self.infos) - self.render_images = np.array(self.render_images, dtype=np.uint8) - for k, v in self.extra_model_outputs.items(): - self.extra_model_outputs[k] = np.array(v) - - def create_successor(self) -> "_Episode": - """Returns a successor episode chunk (of len=0) continuing with this one. - - The successor will have the same ID and state as self and its only observation - will be the last observation in self. Its length will therefore be 0 (no - steps taken yet). - - This method is useful if you would like to discontinue building an episode - chunk (b/c you have to return it from somewhere), but would like to have a new - episode (chunk) instance to continue building the actual env episode at a later - time. - - Returns: - The successor Episode chunk of this one with the same ID and state and the - only observation being the last observation in self. - """ - assert not self.is_done - - return _Episode( - # Same ID. - id_=self.id_, - # First (and only) observation of successor is this episode's last obs. - observations=[self.observations[-1]], - # In addition, first (and only) info of successor is the episode's last - # info. - infos=[self.infos[-1]], - # Same state. - states=self.states, - # Continue with self's current timestep. - t=self.t, - ) - - def to_sample_batch(self): - """Converts an episode to a SampleBatch object.""" - - # If the episode is not done, yet, we need to convert - # to arrays first. - if not self.is_done: - self.convert_lists_to_numpy() - - # Return the sample batch. 
- return SampleBatch( - { - SampleBatch.EPS_ID: np.array([self.id_] * len(self)), - SampleBatch.OBS: self.observations[:-1], - SampleBatch.NEXT_OBS: self.observations[1:], - SampleBatch.ACTIONS: self.actions, - SampleBatch.REWARDS: self.rewards, - SampleBatch.TERMINATEDS: np.array( - [False] * (len(self) - 1) + [self.is_terminated] - ), - SampleBatch.TRUNCATEDS: np.array( - [False] * (len(self) - 1) + [self.is_truncated] - ), - SampleBatch.INFOS: self.infos[:-1], - **self.extra_model_outputs, - } - ) - - @staticmethod - def from_sample_batch(batch): - # TODO (simon): This is very ugly, but right now - # we can only do it according to the exclusion principle. - extra_model_output_keys = [] - for k in batch.keys(): - if k not in [ - SampleBatch.EPS_ID, - SampleBatch.AGENT_INDEX, - SampleBatch.ENV_ID, - SampleBatch.AGENT_INDEX, - SampleBatch.T, - SampleBatch.SEQ_LENS, - SampleBatch.OBS, - SampleBatch.NEXT_OBS, - SampleBatch.ACTIONS, - SampleBatch.PREV_ACTIONS, - SampleBatch.REWARDS, - SampleBatch.PREV_REWARDS, - SampleBatch.TERMINATEDS, - SampleBatch.TRUNCATEDS, - SampleBatch.UNROLL_ID, - SampleBatch.DONES, - SampleBatch.CUR_OBS, - ]: - extra_model_output_keys.append(k) - return _Episode( - id_=batch[SampleBatch.EPS_ID][0], - observations=np.concatenate( - [batch[SampleBatch.OBS], batch[SampleBatch.NEXT_OBS][None, -1]] - ), - actions=batch[SampleBatch.ACTIONS], - rewards=batch[SampleBatch.REWARDS], - is_terminated=batch[SampleBatch.TERMINATEDS][-1], - is_truncated=batch[SampleBatch.TRUNCATEDS][-1], - infos=batch[SampleBatch.INFOS], - extra_model_outputs={k: batch[k] for k in extra_model_output_keys}, - ) - - def get_return(self): - return sum(self.rewards) - - def get_state(self): - return list( - { - "id_": self.id_, - "observations": self.observations, - "actions": self.actions, - "rewards": self.rewards, - "states": self.states, - "t_started": self.t_started, - "t": self.t, - "is_terminated": self.is_terminated, - "is_truncated": self.is_truncated, - **self.extra_model_outputs, - }.items() - ) - - @staticmethod - def from_state(state): - eps = _Episode(id_=state[0][1]) - eps.observations = state[1][1] - eps.actions = state[2][1] - eps.rewards = state[3][1] - eps.infos = state[4][1] - eps.states = state[5][1] - eps.t_started = state[6][1] - eps.t = state[7][1] - eps.is_terminated = state[8][1] - eps.is_truncated = state[9][1] - eps.extra_model_outputs = {*state[10:][1]} - return eps - - def __len__(self): - assert len(self.observations) > 0, ( - "ERROR: Cannot determine length of episode that hasn't started yet! " - "Call `_Episode.add_initial_observation(initial_observation=...)` first " - "(after which `len(_Episode)` will be 0)." 
- ) - return len(self.observations) - 1 diff --git a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py index ee2712ca4c1f..95662eb14cbf 100644 --- a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py @@ -1,9 +1,8 @@ import unittest import numpy as np - +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( - _Episode, EpisodeReplayBuffer, ) @@ -11,13 +10,14 @@ class TestEpisodeReplayBuffer(unittest.TestCase): @staticmethod def _get_episode(episode_len=None, id_=None): - eps = _Episode(id_=id_, observations=[0.0]) + eps = SingleAgentEpisode(id_=id_, observations=[0.0], infos=[{}]) ts = np.random.randint(1, 200) if episode_len is None else episode_len for t in range(ts): eps.add_timestep( observation=float(t + 1), action=int(t), reward=0.1 * (t + 1), + info={}, ) eps.is_terminated = np.random.random() > 0.5 eps.is_truncated = False if eps.is_terminated else np.random.random() > 0.8
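For context, a short editorial sketch (not part of this patch) of how such episodes feed the replay buffer above, mirroring the `_get_episode` helper; the `capacity` keyword below is an assumption about `EpisodeReplayBuffer`'s constructor:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer

# Build one short, terminated episode from dummy data.
eps = SingleAgentEpisode(observations=[0.0], infos=[{}])
for t in range(10):
    eps.add_timestep(
        observation=float(t + 1), action=int(t), reward=0.1 * (t + 1), info={}
    )
eps.is_terminated = True

buffer = EpisodeReplayBuffer(capacity=1000)  # `capacity` kwarg is an assumption.
buffer.add(eps)
print(len(buffer))  # Number of env timesteps currently stored in the buffer.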