From e910a837c2cedf2531ead63b0719d39516149d79 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:45:29 +0000 Subject: [PATCH] [Feature] Parallel collection (#152) * parallel collection * parallel collection * parallel collection * fixes * fixes * fixes * revert buffer update --- benchmarl/algorithms/common.py | 42 +------------- .../conf/experiment/base_experiment.yaml | 7 ++- benchmarl/environments/common.py | 2 +- benchmarl/environments/magent/common.py | 9 +-- benchmarl/environments/meltingpot/common.py | 6 +- benchmarl/environments/pettingzoo/common.py | 7 ++- benchmarl/environments/smacv2/common.py | 5 +- benchmarl/environments/vmas/common.py | 5 +- benchmarl/experiment/experiment.py | 42 +++++++++----- benchmarl/utils.py | 56 ++++++++++++++++++- test/conftest.py | 1 + test/test_meltingpot.py | 28 ++++++++++ test/test_pettingzoo.py | 16 +++++- test/test_task.py | 2 +- 14 files changed, 155 insertions(+), 73 deletions(-) diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index 87e92bbd..96fba1ab 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -13,21 +13,13 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.data import ( Categorical, - Composite, LazyTensorStorage, OneHot, ReplayBuffer, TensorDictReplayBuffer, ) from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement -from torchrl.envs import ( - Compose, - EnvBase, - InitTracker, - TensorDictPrimer, - Transform, - TransformedEnv, -) +from torchrl.envs import Compose, EnvBase, Transform from torchrl.objectives import LossModule from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater @@ -251,38 +243,6 @@ def process_env_fun( Returns: a function that takes no args and creates an enviornment """ - if self.has_rnn: - - def model_fun(): - env = env_fun() - - spec_actor = self.model_config.get_model_state_spec() 
- spec_actor = Composite( - { - group: Composite( - spec_actor.expand(len(agents), *spec_actor.shape), - shape=(len(agents),), - ) - for group, agents in self.group_map.items() - } - ) - - env = TransformedEnv( - env, - Compose( - *( - [InitTracker(init_key="is_init")] - + ( - [TensorDictPrimer(spec_actor, reset_key="_reset")] - if len(spec_actor.keys(True, True)) > 0 - else [] - ) - ) - ), - ) - return env - - return model_fun return env_fun diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index 05d244e6..014aae5f 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -15,6 +15,9 @@ share_policy_params: True prefer_continuous_actions: True # If False collection is done using a collector (under no grad). If True, collection is done with gradients. collect_with_grad: False +# In case of non-vectorized environments, whether to run collection in multiple processes +# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each +parallel_collection: False # Discount factor gamma: 0.9 @@ -51,7 +54,7 @@ max_n_frames: 3_000_000 on_policy_collected_frames_per_batch: 6000 # Number of environments used for collection # If the environment is vectorized, this will be the number of batched environments. -# Otherwise batching will be simulated and each env will be run sequentially. +# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection. on_policy_n_envs_per_worker: 10 # This is the number of times collected_frames_per_batch will be split into minibatches and trained on_policy_n_minibatch_iters: 45 @@ -63,7 +66,7 @@ on_policy_minibatch_size: 400 off_policy_collected_frames_per_batch: 6000 # Number of environments used for collection # If the environment is vectorized, this will be the number of batched environments. 
-# Otherwise batching will be simulated and each env will be run sequentially. +# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection. off_policy_n_envs_per_worker: 10 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over. off_policy_n_optimizer_steps: 1000 diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py index 63ac150e..cccccdfb 100644 --- a/benchmarl/environments/common.py +++ b/benchmarl/environments/common.py @@ -34,7 +34,7 @@ def _type_check_task_config( else: if warn_on_missing_dataclass: warnings.warn( - "TaskConfig python dataclass not foud, task is being loaded without type checks" + "TaskConfig python dataclass not found, task is being loaded without type checks" ) return config diff --git a/benchmarl/environments/magent/common.py b/benchmarl/environments/magent/common.py index b8964ddd..08cae772 100644 --- a/benchmarl/environments/magent/common.py +++ b/benchmarl/environments/magent/common.py @@ -3,7 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# - +import copy from typing import Callable, Dict, List, Optional from torchrl.data import Composite @@ -31,9 +31,10 @@ def get_env_fun( seed: Optional[int], device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: + config = copy.deepcopy(self.config) return lambda: PettingZooWrapper( - env=self.__get_env(), + env=self.__get_env(config), return_state=True, seed=seed, done_on_any=False, @@ -41,7 +42,7 @@ def get_env_fun( device=device, ) - def __get_env(self) -> EnvBase: + def __get_env(self, config) -> EnvBase: try: from magent2.environments import ( adversarial_pursuit_v4, @@ -66,7 +67,7 @@ def __get_env(self) -> EnvBase: } if self.name not in envs: raise Exception(f"{self.name} is not an environment of MAgent2") - return envs[self.name].parallel_env(**self.config, render_mode="rgb_array") + return envs[self.name].parallel_env(**config, render_mode="rgb_array") def supports_continuous_actions(self) -> bool: return False diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py index 848ceaa3..57423b13 100644 --- a/benchmarl/environments/meltingpot/common.py +++ b/benchmarl/environments/meltingpot/common.py @@ -3,7 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# - +import copy from typing import Callable, Dict, List, Optional import torch @@ -84,11 +84,13 @@ def get_env_fun( ) -> Callable[[], EnvBase]: from torchrl.envs.libs.meltingpot import MeltingpotEnv + config = copy.deepcopy(self.config) + return lambda: MeltingpotEnv( substrate=self.name.lower(), categorical_actions=True, device=device, - **self.config, + **config, ) def supports_continuous_actions(self) -> bool: diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py index 1bb3cd15..f6078c99 100644 --- a/benchmarl/environments/pettingzoo/common.py +++ b/benchmarl/environments/pettingzoo/common.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # +import copy from typing import Callable, Dict, List, Optional from torchrl.data import Composite @@ -35,9 +36,9 @@ def get_env_fun( seed: Optional[int], device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: + config = copy.deepcopy(self.config) if self.supports_continuous_actions() and self.supports_discrete_actions(): - self.config.update({"continuous_actions": continuous_actions}) - + config.update({"continuous_actions": continuous_actions}) return lambda: PettingZooEnv( categorical_actions=True, device=device, @@ -45,7 +46,7 @@ def get_env_fun( parallel=True, return_state=self.has_state(), render_mode="rgb_array", - **self.config + **config ) def supports_continuous_actions(self) -> bool: diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py index dc87f6b7..08b972a1 100644 --- a/benchmarl/environments/smacv2/common.py +++ b/benchmarl/environments/smacv2/common.py @@ -3,7 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# - +import copy from typing import Callable, Dict, List, Optional import torch @@ -42,8 +42,9 @@ def get_env_fun( seed: Optional[int], device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: + config = copy.deepcopy(self.config) return lambda: SMACv2Env( - categorical_actions=True, seed=seed, device=device, **self.config + categorical_actions=True, seed=seed, device=device, **config ) def supports_continuous_actions(self) -> bool: diff --git a/benchmarl/environments/vmas/common.py b/benchmarl/environments/vmas/common.py index 9c648045..dd77a249 100644 --- a/benchmarl/environments/vmas/common.py +++ b/benchmarl/environments/vmas/common.py @@ -3,7 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # - +import copy from typing import Callable, Dict, List, Optional from torchrl.data import Composite @@ -52,6 +52,7 @@ def get_env_fun( seed: Optional[int], device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: + config = copy.deepcopy(self.config) return lambda: VmasEnv( scenario=self.name.lower(), num_envs=num_envs, @@ -60,7 +61,7 @@ def get_env_fun( device=device, categorical_actions=True, clamp_actions=True, - **self.config, + **config, ) def supports_continuous_actions(self) -> bool: diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index bd72a717..2ea0f92c 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -14,13 +14,15 @@ from collections import deque, OrderedDict from dataclasses import dataclass, MISSING from pathlib import Path + from typing import Any, Dict, List, Optional import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictSequential from torchrl.collectors import SyncDataCollector -from torchrl.envs import SerialEnv, TransformedEnv + +from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv from torchrl.envs.transforms import Compose from torchrl.envs.utils import 
ExplorationType, set_exploration_type, step_mdp from torchrl.record.loggers import generate_exp_name @@ -34,7 +36,7 @@ from benchmarl.experiment.logger import Logger from benchmarl.models import GnnConfig, SequenceModelConfig from benchmarl.models.common import ModelConfig -from benchmarl.utils import _read_yaml_config, seed_everything +from benchmarl.utils import _add_rnn_transforms, _read_yaml_config, seed_everything _has_hydra = importlib.util.find_spec("hydra") is not None if _has_hydra: @@ -58,6 +60,7 @@ class ExperimentConfig: share_policy_params: bool = MISSING prefer_continuous_actions: bool = MISSING collect_with_grad: bool = MISSING + parallel_collection: bool = MISSING gamma: float = MISSING lr: float = MISSING @@ -430,20 +433,10 @@ def _setup_task(self): transforms_training = transforms_env + [ self.task.get_reward_sum_transform(test_env) ] - transforms_env = Compose(*transforms_env) transforms_training = Compose(*transforms_training) - if test_env.batch_size == (): - self.env_func = lambda: TransformedEnv( - SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func), - transforms_training.clone(), - ) - else: - self.env_func = lambda: TransformedEnv( - env_func(), transforms_training.clone() - ) - + # Initialize test env self.test_env = TransformedEnv(test_env, transforms_env.clone()).to( self.config.sampling_device ) @@ -457,6 +450,29 @@ def _setup_task(self): self.train_group_map = copy.deepcopy(self.group_map) self.max_steps = self.task.max_steps(self.test_env) + # Add rnn transforms here so they do not show in the benchmarl specs + if self.model_config.is_rnn: + self.test_env = _add_rnn_transforms( + lambda: self.test_env, self.group_map, self.model_config + )() + env_func = _add_rnn_transforms(env_func, self.group_map, self.model_config) + + # Initialize train env + if self.test_env.batch_size == (): + # If the environment is not vectorized, we simulate vectorization using parallel or serial environments + env_class = ( + SerialEnv if not 
self.config.parallel_collection else ParallelEnv + ) + self.env_func = lambda: TransformedEnv( + env_class(self.config.n_envs_per_worker(self.on_policy), env_func), + transforms_training.clone(), + ) + else: + # Otherwise it is already vectorized + self.env_func = lambda: TransformedEnv( + env_func(), transforms_training.clone() + ) + def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm(experiment=self) diff --git a/benchmarl/utils.py b/benchmarl/utils.py index d2d63ae6..efe36e27 100644 --- a/benchmarl/utils.py +++ b/benchmarl/utils.py @@ -6,10 +6,16 @@ import importlib import random -from typing import Any, Dict, Union +import typing +from typing import Any, Callable, Dict, List, Union import torch import yaml +from torchrl.data import Composite +from torchrl.envs import Compose, EnvBase, InitTracker, TensorDictPrimer, TransformedEnv + +if typing.TYPE_CHECKING: + from benchmarl.models import ModelConfig _has_numpy = importlib.util.find_spec("numpy") is not None @@ -53,3 +59,51 @@ def seed_everything(seed: int): import numpy numpy.random.seed(seed) + + +def _add_rnn_transforms( + env_fun: Callable[[], EnvBase], + group_map: Dict[str, List[str]], + model_config: "ModelConfig", +) -> Callable[[], EnvBase]: + """ + This function adds RNN specific transforms to the environment + + Args: + env_fun (callable): a function that takes no args and creates an environment + group_map (Dict[str,List[str]]): the group_map of the agents + model_config (ModelConfig): the model configuration + + Returns: a function that takes no args and creates an environment + + """ + + def model_fun(): + env = env_fun() + spec_actor = model_config.get_model_state_spec() + spec_actor = Composite( + { + group: Composite( + spec_actor.expand(len(agents), *spec_actor.shape), + shape=(len(agents),), + ) + for group, agents in group_map.items() + } + ) + + out_env = TransformedEnv( + env, + Compose( + *( + [InitTracker(init_key="is_init")] + + ( + 
[TensorDictPrimer(spec_actor, reset_key="_reset")] + if len(spec_actor.keys(True, True)) > 0 + else [] + ) + ) + ), + ) + return out_env + + return model_fun diff --git a/test/conftest.py b/test/conftest.py index 3f53416e..5ce3bde8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -29,6 +29,7 @@ def experiment_config(tmp_path) -> ExperimentConfig: experiment_config.on_policy_n_envs_per_worker = ( experiment_config.off_policy_n_envs_per_worker ) = 2 + experiment_config.parallel_collection = False experiment_config.off_policy_n_optimizer_steps = 2 experiment_config.off_policy_train_batch_size = 3 experiment_config.off_policy_memory_size = 200 diff --git a/test/test_meltingpot.py b/test/test_meltingpot.py index 2e7b20f9..07ecb2d4 100644 --- a/test/test_meltingpot.py +++ b/test/test_meltingpot.py @@ -10,6 +10,7 @@ from benchmarl.algorithms import ( algorithm_config_registry, IppoConfig, + MappoConfig, MasacConfig, QmixConfig, ) @@ -78,6 +79,33 @@ def test_all_tasks( ) experiment.run() + @pytest.mark.parametrize("algo_config", [MappoConfig]) + @pytest.mark.parametrize("task", [MeltingPotTask.COINS]) + @pytest.mark.parametrize("parallel_collection", [True, False]) + def test_lstm( + self, + algo_config: AlgorithmConfig, + task: Task, + parallel_collection: bool, + experiment_config, + cnn_lstm_sequence_config, + ): + algo_config = algo_config.get_from_yaml() + if algo_config.has_critic(): + algo_config.share_param_critic = False + experiment_config.parallel_collection = parallel_collection + experiment_config.share_policy_params = False + task = task.get_from_yaml() + experiment = Experiment( + algorithm_config=algo_config, + model_config=cnn_lstm_sequence_config, + critic_model_config=cnn_lstm_sequence_config, + seed=0, + config=experiment_config, + task=task, + ) + experiment.run() + @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) @pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN]) def test_reloading_trainer( diff 
--git a/test/test_pettingzoo.py b/test/test_pettingzoo.py index a726c017..5ce637c5 100644 --- a/test/test_pettingzoo.py +++ b/test/test_pettingzoo.py @@ -6,10 +6,12 @@ import pytest + from benchmarl.algorithms import ( algorithm_config_registry, IddpgConfig, IppoConfig, + IqlConfig, IsacConfig, MaddpgConfig, MappoConfig, @@ -109,16 +111,19 @@ def test_gnn( "algo_config", [IddpgConfig, MappoConfig, QmixConfig, MasacConfig] ) @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG]) + @pytest.mark.parametrize("parallel_collection", [True, False]) def test_gru( self, algo_config: AlgorithmConfig, task: Task, + parallel_collection: bool, experiment_config, gru_mlp_sequence_config, ): algo_config = algo_config.get_from_yaml() if algo_config.has_critic(): algo_config.share_param_critic = False + experiment_config.parallel_collection = parallel_collection experiment_config.share_policy_params = False task = task.get_from_yaml() experiment = Experiment( @@ -157,17 +162,26 @@ def test_lstm( ) experiment.run() - @pytest.mark.parametrize("algo_config", algorithm_config_registry.values()) + @pytest.mark.parametrize("algo_config", [MappoConfig, IsacConfig, IqlConfig]) @pytest.mark.parametrize("prefer_continuous", [True, False]) @pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG]) + @pytest.mark.parametrize("parallel_collection", [True, False]) def test_reloading_trainer( self, algo_config: AlgorithmConfig, task: Task, + parallel_collection, experiment_config, mlp_sequence_config, prefer_continuous, ): + # To not run the same test twice + if (prefer_continuous and not algo_config.supports_continuous_actions()) or ( + not prefer_continuous and not algo_config.supports_discrete_actions() + ): + pytest.skip() + + experiment_config.parallel_collection = parallel_collection experiment_config.prefer_continuous_actions = prefer_continuous algo_config = algo_config.get_from_yaml() diff --git a/test/test_task.py b/test/test_task.py index fc1660f6..e6c65c0b 100644 --- 
a/test/test_task.py +++ b/test/test_task.py @@ -34,7 +34,7 @@ def test_loading_tasks(task_name): task_name_hydra = cfg.hydra.runtime.choices.task assert task_name_hydra == task_name - warn_message = "TaskConfig python dataclass not foud, task is being loaded without type checks" + warn_message = "TaskConfig python dataclass not found, task is being loaded without type checks" with ( pytest.warns(match=warn_message)