From f3c33cb63c735d9ac2b5ec686df9eb865b2fd8fc Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:43:23 +0200 Subject: [PATCH] [Refactor] Update values headed to deprecation (#118) --- benchmarl/algorithms/iddpg.py | 14 ++++++++------ benchmarl/algorithms/ippo.py | 4 ++-- benchmarl/algorithms/isac.py | 20 +++++++++++--------- benchmarl/algorithms/maddpg.py | 14 ++++++++------ benchmarl/algorithms/mappo.py | 4 ++-- benchmarl/algorithms/masac.py | 20 +++++++++++--------- benchmarl/experiment/experiment.py | 2 +- setup.py | 2 +- 8 files changed, 44 insertions(+), 36 deletions(-) diff --git a/benchmarl/algorithms/iddpg.py b/benchmarl/algorithms/iddpg.py index a114123b..334aa526 100644 --- a/benchmarl/algorithms/iddpg.py +++ b/benchmarl/algorithms/iddpg.py @@ -123,12 +123,14 @@ def _get_policy_for_loss( in_keys=[(group, "param")], out_keys=[(group, "action")], distribution_class=TanhDelta if self.use_tanh_mapping else Delta, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_mapping - else {}, + distribution_kwargs=( + { + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_mapping + else {} + ), return_log_prob=False, safe=not self.use_tanh_mapping, ) diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py index aac2cd88..012c6880 100644 --- a/benchmarl/algorithms/ippo.py +++ b/benchmarl/algorithms/ippo.py @@ -158,8 +158,8 @@ def _get_policy_for_loss( ), distribution_kwargs=( { - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, } if self.use_tanh_normal else {} diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index 74c29ea7..03024c68 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -199,15 +199,17 @@ def _get_policy_for_loss( spec=self.action_spec[group, "action"], in_keys=[(group, "loc"), (group, "scale")], out_keys=[(group, "action")], - distribution_class=IndependentNormal - if not self.use_tanh_normal - else TanhNormal, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_normal - else {}, + distribution_class=( + IndependentNormal if not self.use_tanh_normal else TanhNormal + ), + distribution_kwargs=( + { + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_normal + else {} + ), return_log_prob=True, log_prob_key=(group, "log_prob"), ) diff --git a/benchmarl/algorithms/maddpg.py b/benchmarl/algorithms/maddpg.py index 1590f81f..c3ad1069 100644 --- a/benchmarl/algorithms/maddpg.py +++ b/benchmarl/algorithms/maddpg.py @@ -123,12 +123,14 @@ def _get_policy_for_loss( in_keys=[(group, "param")], out_keys=[(group, "action")], distribution_class=TanhDelta if self.use_tanh_mapping else Delta, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_mapping - else {}, + distribution_kwargs=( + { + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_mapping + else {} + ), return_log_prob=False, safe=not self.use_tanh_mapping, ) diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py index 891200ef..3ddd8d53 100644 --- a/benchmarl/algorithms/mappo.py +++ b/benchmarl/algorithms/mappo.py @@ -162,8 +162,8 @@ def _get_policy_for_loss( ), distribution_kwargs=( { - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, } if self.use_tanh_normal else {} diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index 358010ef..1991403e 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -199,15 +199,17 @@ def _get_policy_for_loss( spec=self.action_spec[group, "action"], in_keys=[(group, "loc"), (group, "scale")], out_keys=[(group, "action")], - distribution_class=IndependentNormal - if not self.use_tanh_normal - else TanhNormal, - distribution_kwargs={ - "min": self.action_spec[(group, "action")].space.low, - "max": self.action_spec[(group, "action")].space.high, - } - if self.use_tanh_normal - else {}, + distribution_class=( + IndependentNormal if not self.use_tanh_normal else TanhNormal + ), + distribution_kwargs=( + { + "low": self.action_spec[(group, "action")].space.low, + "high": self.action_spec[(group, "action")].space.high, + } + if self.use_tanh_normal + else {} + ), return_log_prob=True, log_prob_key=(group, "log_prob"), ) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index b09a7abc..b8a294c8 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -758,7 +758,7 @@ def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float: def _evaluation_loop(self): evaluation_start = time.time() with set_exploration_type( - ExplorationType.MODE + ExplorationType.DETERMINISTIC if self.config.evaluation_deterministic_actions else ExplorationType.RANDOM ): diff --git a/setup.py b/setup.py index 636f9166..41c7c868 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def get_version(): url="https://github.com/facebookresearch/BenchMARL", author="Matteo Bettini", author_email="mb2389@cl.cam.ac.uk", - install_requires=["torchrl>=0.4.0", "tqdm", "hydra-core"], + install_requires=["torchrl>=0.5.0", "tqdm", "hydra-core"], extras_require={ "vmas": ["vmas>=1.3.4"], "pettingzoo": ["pettingzoo[all]>=1.24.3"],