Commit

fixes

matteobettini committed Dec 17, 2024
1 parent bafbe86 commit a9db46b
Showing 3 changed files with 97 additions and 64 deletions.
42 changes: 1 addition & 41 deletions benchmarl/algorithms/common.py
@@ -13,21 +13,13 @@
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
Categorical,
Composite,
LazyTensorStorage,
OneHot,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
from torchrl.envs import (
Compose,
EnvBase,
InitTracker,
TensorDictPrimer,
Transform,
TransformedEnv,
)
from torchrl.envs import Compose, EnvBase, Transform
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -251,38 +243,6 @@ def process_env_fun(
Returns: a function that takes no args and creates an environment
"""
if self.has_rnn:

def model_fun():
env = env_fun()

spec_actor = self.model_config.get_model_state_spec()
spec_actor = Composite(
{
group: Composite(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in self.group_map.items()
}
)

env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor, reset_key="_reset")]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return env

return model_fun

return env_fun

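With this change the algorithm no longer injects RNN transforms, so the default hook reduces to a pass-through. A minimal sketch of the resulting method, assuming its signature mirrors _add_rnn_transforms in experiment.py below (only the final return env_fun line is visible in this diff):

from typing import Callable

from torchrl.envs import EnvBase


class Algorithm:  # illustrative stub of benchmarl.algorithms.common.Algorithm
    def process_env_fun(
        self,
        env_fun: Callable[[], EnvBase],
    ) -> Callable[[], EnvBase]:
        # RNN-specific transforms are no longer added here; the experiment
        # applies them via Experiment._add_rnn_transforms instead.
        return env_fun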
100 changes: 81 additions & 19 deletions benchmarl/experiment/experiment.py
@@ -14,14 +14,23 @@
from collections import deque, OrderedDict
from dataclasses import dataclass, MISSING
from pathlib import Path
from typing import Any, Dict, List, Optional

from typing import Any, Callable, Dict, List, Optional

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential

from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv

from torchrl.data import Composite
from torchrl.envs import (
EnvBase,
InitTracker,
ParallelEnv,
SerialEnv,
TensorDictPrimer,
TransformedEnv,
)
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
@@ -361,6 +370,7 @@ def _setup(self):
self._setup_task()
self._setup_algorithm()
self._setup_collector()
self._setup_buffers()
self._setup_name()
self._setup_logger()
self._on_setup()
@@ -436,7 +446,21 @@ def _setup_task(self):
transforms_env = Compose(*transforms_env)
transforms_training = Compose(*transforms_training)

self.observation_spec = self.task.observation_spec(test_env)
self.info_spec = self.task.info_spec(test_env)
self.state_spec = self.task.state_spec(test_env)
self.action_mask_spec = self.task.action_mask_spec(test_env)
self.action_spec = self.task.action_spec(test_env)
self.group_map = self.task.group_map(test_env)
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(test_env)

if self.model_config.is_rnn:
test_env = self._add_rnn_transforms(lambda: test_env)()
env_func = self._add_rnn_transforms(env_func)

if test_env.batch_size == ():
# If the environment is not vectorized, we simulate vectorization using parallel or serial environments
env_class = (
SerialEnv if not self.config.parallel_collection else ParallelEnv
)
@@ -453,28 +477,12 @@
self.config.sampling_device
)

self.observation_spec = self.task.observation_spec(self.test_env)
self.info_spec = self.task.info_spec(self.test_env)
self.state_spec = self.task.state_spec(self.test_env)
self.action_mask_spec = self.task.action_mask_spec(self.test_env)
self.action_spec = self.task.action_spec(self.test_env)
self.group_map = self.task.group_map(self.test_env)
self.train_group_map = copy.deepcopy(self.group_map)
self.max_steps = self.task.max_steps(self.test_env)

def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
self.env_func = self.algorithm.process_env_fun(self.env_func)

self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}
self.losses = {
group: self.algorithm.get_loss_and_updater(group)[0]
for group in self.group_map.keys()
@@ -523,6 +531,15 @@ def _setup_collector(self):
)
self.rollout_env = self.env_func().to(self.config.sampling_device)

def _setup_buffers(self):
self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
)
for group in self.group_map.keys()
}

def _setup_name(self):
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
self.model_name = self.model_config.associated_class().__name__.lower()
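The buffers created in the new _setup_buffers step come from Algorithm.get_replay_buffer, whose building blocks are visible in the imports kept in benchmarl/algorithms/common.py (TensorDictReplayBuffer, LazyTensorStorage, RandomSampler, SamplerWithoutReplacement). A minimal sketch of such a buffer, with assumed sizes rather than the values the real factory uses:

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import RandomSampler

# Assumed sizes for illustration: 10_000 stored frames, 128-frame batches.
# The real factory is also given the group transforms gathered via
# task.get_replay_buffer_transforms(self.test_env, group) in _setup_buffers.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10_000),
    sampler=RandomSampler(),  # on-policy algorithms use SamplerWithoutReplacement
    batch_size=128,
)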
@@ -929,3 +946,48 @@ def _load_experiment(self) -> Experiment:
)
self.load_state_dict(loaded_dict)
return self

def _add_rnn_transforms(
self,
env_fun: Callable[[], EnvBase],
) -> Callable[[], EnvBase]:
"""
This function adds RNN-specific transforms to the environment.
Args:
env_fun (callable): a function that takes no args and creates an environment
Returns: a function that takes no args and creates an environment
"""

def model_fun():
env = env_fun()
group_map = self.task.group_map(env)
spec_actor = self.model_config.get_model_state_spec()
spec_actor = Composite(
{
group: Composite(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in group_map.items()
}
)

out_env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor, reset_key="_reset")]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return out_env

return model_fun
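In isolation, the transform stack that model_fun assembles is an InitTracker plus a TensorDictPrimer keyed on the model's recurrent-state spec. A minimal sketch of the same wrapping, assuming a single group of 2 agents with a hidden size of 128 (the real spec comes from model_config.get_model_state_spec() and the task's group map; the key names below are illustrative):

from torchrl.data import Composite, Unbounded
from torchrl.envs import (
    Compose,
    EnvBase,
    InitTracker,
    TensorDictPrimer,
    TransformedEnv,
)

# Assumed recurrent-state spec; "agents" and "_hidden" are placeholder names.
spec_actor = Composite(
    {"agents": Composite({"_hidden": Unbounded(shape=(2, 128))}, shape=(2,))}
)


def wrap(env: EnvBase) -> EnvBase:
    return TransformedEnv(
        env,
        Compose(
            InitTracker(init_key="is_init"),  # flags the first step after each reset
            TensorDictPrimer(spec_actor, reset_key="_reset"),  # zero-fills the hidden state
        ),
    )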
19 changes: 15 additions & 4 deletions test/test_pettingzoo.py
@@ -6,10 +6,12 @@


import pytest

from benchmarl.algorithms import (
algorithm_config_registry,
IddpgConfig,
IppoConfig,
IqlConfig,
IsacConfig,
MaddpgConfig,
MappoConfig,
@@ -68,16 +70,13 @@ def test_all_algos(

@pytest.mark.parametrize("algo_config", [IppoConfig, MasacConfig])
@pytest.mark.parametrize("task", list(PettingZooTask))
@pytest.mark.parametrize("parallel_collection", [True, False])
def test_all_tasks(
self,
algo_config: AlgorithmConfig,
task: Task,
parallel_collection,
experiment_config,
mlp_sequence_config,
):
experiment_config.parallel_collection = parallel_collection
task = task.get_from_yaml()
experiment = Experiment(
algorithm_config=algo_config.get_from_yaml(),
@@ -112,16 +111,19 @@ def test_gnn(
"algo_config", [IddpgConfig, MappoConfig, QmixConfig, MasacConfig]
)
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.parametrize("parallel_collection", [True, False])
def test_gru(
self,
algo_config: AlgorithmConfig,
task: Task,
parallel_collection: bool,
experiment_config,
gru_mlp_sequence_config,
):
algo_config = algo_config.get_from_yaml()
if algo_config.has_critic():
algo_config.share_param_critic = False
experiment_config.parallel_collection = parallel_collection
experiment_config.share_policy_params = False
task = task.get_from_yaml()
experiment = Experiment(
@@ -160,17 +162,26 @@ def test_lstm(
)
experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("algo_config", [MappoConfig, IsacConfig, IqlConfig])
@pytest.mark.parametrize("prefer_continuous", [True, False])
@pytest.mark.parametrize("task", [PettingZooTask.SIMPLE_TAG])
@pytest.mark.parametrize("parallel_collection", [True, False])
def test_reloading_trainer(
self,
algo_config: AlgorithmConfig,
task: Task,
parallel_collection,
experiment_config,
mlp_sequence_config,
prefer_continuous,
):
# Avoid running the same test configuration twice
if (prefer_continuous and not algo_config.supports_continuous_actions()) or (
not prefer_continuous and not algo_config.supports_discrete_actions()
):
pytest.skip()

experiment_config.parallel_collection = parallel_collection
experiment_config.prefer_continuous_actions = prefer_continuous
algo_config = algo_config.get_from_yaml()

