diff --git a/benchmarl/algorithms/ippo.py b/benchmarl/algorithms/ippo.py
index 45e7e7c8..1ddf6172 100644
--- a/benchmarl/algorithms/ippo.py
+++ b/benchmarl/algorithms/ippo.py
@@ -49,6 +49,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        pre_compute_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -61,6 +62,7 @@
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.pre_compute_advantage = pre_compute_advantage
 
     #############################
     # Overridden abstract methods
@@ -148,15 +150,17 @@ def _get_policy_for_loss(
             spec=self.action_spec[group, "action"],
             in_keys=[(group, "loc"), (group, "scale")],
             out_keys=[(group, "action")],
-            distribution_class=IndependentNormal
-            if not self.use_tanh_normal
-            else TanhNormal,
-            distribution_kwargs={
-                "min": self.action_spec[(group, "action")].space.low,
-                "max": self.action_spec[(group, "action")].space.high,
-            }
-            if self.use_tanh_normal
-            else {},
+            distribution_class=(
+                IndependentNormal if not self.use_tanh_normal else TanhNormal
+            ),
+            distribution_kwargs=(
+                {
+                    "min": self.action_spec[(group, "action")].space.low,
+                    "max": self.action_spec[(group, "action")].space.high,
+                }
+                if self.use_tanh_normal
+                else {}
+            ),
             return_log_prob=True,
             log_prob_key=(group, "log_prob"),
         )
@@ -220,14 +224,14 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
             nested_reward_key,
             batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
         )
-
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
-            )
+        if self.pre_compute_advantage:
+            with torch.no_grad():
+                loss = self.get_loss_and_updater(group)[0]
+                loss.value_estimator(
+                    batch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )
 
         return batch
 
@@ -285,6 +289,7 @@ class IppoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    pre_compute_advantage: bool = MISSING
 
     @staticmethod
     def associated_class() -> Type[Algorithm]:
diff --git a/benchmarl/algorithms/mappo.py b/benchmarl/algorithms/mappo.py
index 0a391d2e..d3099693 100644
--- a/benchmarl/algorithms/mappo.py
+++ b/benchmarl/algorithms/mappo.py
@@ -53,6 +53,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        pre_compute_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -65,6 +66,7 @@
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.pre_compute_advantage = pre_compute_advantage
 
     #############################
     # Overridden abstract methods
@@ -152,15 +154,17 @@ def _get_policy_for_loss(
             spec=self.action_spec[group, "action"],
             in_keys=[(group, "loc"), (group, "scale")],
             out_keys=[(group, "action")],
-            distribution_class=IndependentNormal
-            if not self.use_tanh_normal
-            else TanhNormal,
-            distribution_kwargs={
-                "min": self.action_spec[(group, "action")].space.low,
-                "max": self.action_spec[(group, "action")].space.high,
-            }
-            if self.use_tanh_normal
-            else {},
+            distribution_class=(
+                IndependentNormal if not self.use_tanh_normal else TanhNormal
+            ),
+            distribution_kwargs=(
+                {
+                    "min": self.action_spec[(group, "action")].space.low,
+                    "max": self.action_spec[(group, "action")].space.high,
+                }
+                if self.use_tanh_normal
+                else {}
+            ),
             return_log_prob=True,
             log_prob_key=(group, "log_prob"),
         )
@@ -224,14 +228,14 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
             nested_reward_key,
             batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
         )
-
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
-            )
+        if self.pre_compute_advantage:
+            with torch.no_grad():
+                loss = self.get_loss_and_updater(group)[0]
+                loss.value_estimator(
+                    batch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )
 
         return batch
 
@@ -321,6 +325,7 @@ class MappoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    pre_compute_advantage: bool = MISSING
 
     @staticmethod
     def associated_class() -> Type[Algorithm]:
diff --git a/benchmarl/conf/algorithm/ippo.yaml b/benchmarl/conf/algorithm/ippo.yaml
index 2cda60df..2e4c13a9 100644
--- a/benchmarl/conf/algorithm/ippo.yaml
+++ b/benchmarl/conf/algorithm/ippo.yaml
@@ -11,3 +11,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+pre_compute_advantage: False
diff --git a/benchmarl/conf/algorithm/mappo.yaml b/benchmarl/conf/algorithm/mappo.yaml
index db194d5f..20da8dc6 100644
--- a/benchmarl/conf/algorithm/mappo.yaml
+++ b/benchmarl/conf/algorithm/mappo.yaml
@@ -12,3 +12,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+pre_compute_advantage: False
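
Note for reviewers: with the shipped default of pre_compute_advantage: False, process_batch no longer pre-computes the advantage over the whole collected batch; setting the flag to True opts back into the previous behavior. A minimal sketch of flipping the flag programmatically, assuming BenchMARL's standard Experiment entry point as documented in the README (the VMAS task and MLP model choices here are purely illustrative):

from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

# Load the defaults from benchmarl/conf/algorithm/mappo.yaml, then opt back
# in to pre-computing the advantage once per collected batch in process_batch().
algorithm_config = MappoConfig.get_from_yaml()
algorithm_config.pre_compute_advantage = True

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),  # illustrative task choice
    algorithm_config=algorithm_config,
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()

The same override should also be reachable from the Hydra CLI, e.g. by appending algorithm.pre_compute_advantage=True to a benchmarl/run.py invocation.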