diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py
index 5a9a7536..63ac150e 100644
--- a/benchmarl/environments/common.py
+++ b/benchmarl/environments/common.py
@@ -282,13 +282,14 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
         """
         return []
 
-    def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
+    def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
         """
-        Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`.
+        Returns a list of :class:`torchrl.envs.Transform` to be applied to the :class:`torchrl.data.ReplayBuffer`
+        of the specified group.
 
         Args:
             env (EnvBase): An environment created via self.get_env_fun
-
+            group (str): The agent group using the replay buffer
         """
         return []
 
diff --git a/benchmarl/environments/meltingpot/common.py b/benchmarl/environments/meltingpot/common.py
index 82a87fbc..848ceaa3 100644
--- a/benchmarl/environments/meltingpot/common.py
+++ b/benchmarl/environments/meltingpot/common.py
@@ -125,22 +125,16 @@ def get_env_transforms(self, env: EnvBase) -> List[Transform]:
             else []
         )
 
-    def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
+    def get_replay_buffer_transforms(self, env: EnvBase, group: str) -> List[Transform]:
         return [
             DTypeCastTransform(
                 dtype_in=torch.uint8,
                 dtype_out=torch.float,
                 in_keys=[
                     "RGB",
-                    *[
-                        (group, "observation", "RGB")
-                        for group in self.group_map(env).keys()
-                    ],
+                    (group, "observation", "RGB"),
                     ("next", "RGB"),
-                    *[
-                        ("next", group, "observation", "RGB")
-                        for group in self.group_map(env).keys()
-                    ],
+                    ("next", group, "observation", "RGB"),
                 ],
                 in_keys_inv=[],
             )
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index ff0c47cf..bd72a717 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -466,7 +466,7 @@ def _setup_algorithm(self):
         self.replay_buffers = {
             group: self.algorithm.get_replay_buffer(
                 group=group,
-                transforms=self.task.get_replay_buffer_transforms(self.test_env),
+                transforms=self.task.get_replay_buffer_transforms(self.test_env, group),
             )
             for group in self.group_map.keys()
         }