disable pre compute advantage
matteobettini committed Jun 12, 2024
1 parent a5c629b commit 8aea1a2
Showing 4 changed files with 46 additions and 34 deletions.
39 changes: 22 additions & 17 deletions benchmarl/algorithms/ippo.py
@@ -49,6 +49,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        pre_compute_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -61,6 +62,7 @@ def __init__(
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.pre_compute_advantage = pre_compute_advantage

         #############################
         # Overridden abstract methods
@@ -148,15 +150,17 @@ def _get_policy_for_loss(
                 spec=self.action_spec[group, "action"],
                 in_keys=[(group, "loc"), (group, "scale")],
                 out_keys=[(group, "action")],
-                distribution_class=IndependentNormal
-                if not self.use_tanh_normal
-                else TanhNormal,
-                distribution_kwargs={
-                    "min": self.action_spec[(group, "action")].space.low,
-                    "max": self.action_spec[(group, "action")].space.high,
-                }
-                if self.use_tanh_normal
-                else {},
+                distribution_class=(
+                    IndependentNormal if not self.use_tanh_normal else TanhNormal
+                ),
+                distribution_kwargs=(
+                    {
+                        "min": self.action_spec[(group, "action")].space.low,
+                        "max": self.action_spec[(group, "action")].space.high,
+                    }
+                    if self.use_tanh_normal
+                    else {}
+                ),
                 return_log_prob=True,
                 log_prob_key=(group, "log_prob"),
             )
@@ -220,14 +224,14 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
                 nested_reward_key,
                 batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
             )
-
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
-            )
+        if self.pre_compute_advantage:
+            with torch.no_grad():
+                loss = self.get_loss_and_updater(group)[0]
+                loss.value_estimator(
+                    batch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )

         return batch
@@ -285,6 +289,7 @@ class IppoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    pre_compute_advantage: bool = MISSING

     @staticmethod
     def associated_class() -> Type[Algorithm]:
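For context on the behaviour being toggled: when pre_compute_advantage is True, process_batch runs the loss's value estimator (GAE here, given the lmbda field) once over the whole collected batch under torch.no_grad(), so the advantage is written into the batch and stays fixed across minibatches; with the new False default nothing is written here, and TorchRL's PPO loss computes the advantage itself when the key is absent, i.e. freshly for each minibatch with the current critic. A minimal sketch of opting back in through the Python API (IppoConfig.get_from_yaml is BenchMARL's stock loader; the flag name follows the diff above):

from benchmarl.algorithms import IppoConfig

# Defaults come from benchmarl/conf/algorithm/ippo.yaml, which after
# this commit ships pre_compute_advantage: False.
algorithm_config = IppoConfig.get_from_yaml()

# Opt back into the old behaviour: GAE is computed once per collected
# batch in process_batch, before the batch is split into minibatches.
algorithm_config.pre_compute_advantage = True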
39 changes: 22 additions & 17 deletions benchmarl/algorithms/mappo.py
@@ -53,6 +53,7 @@ def __init__(
         lmbda: float,
         scale_mapping: str,
         use_tanh_normal: bool,
+        pre_compute_advantage: bool,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -65,6 +66,7 @@ def __init__(
         self.lmbda = lmbda
         self.scale_mapping = scale_mapping
         self.use_tanh_normal = use_tanh_normal
+        self.pre_compute_advantage = pre_compute_advantage

         #############################
         # Overridden abstract methods
@@ -152,15 +154,17 @@ def _get_policy_for_loss(
                 spec=self.action_spec[group, "action"],
                 in_keys=[(group, "loc"), (group, "scale")],
                 out_keys=[(group, "action")],
-                distribution_class=IndependentNormal
-                if not self.use_tanh_normal
-                else TanhNormal,
-                distribution_kwargs={
-                    "min": self.action_spec[(group, "action")].space.low,
-                    "max": self.action_spec[(group, "action")].space.high,
-                }
-                if self.use_tanh_normal
-                else {},
+                distribution_class=(
+                    IndependentNormal if not self.use_tanh_normal else TanhNormal
+                ),
+                distribution_kwargs=(
+                    {
+                        "min": self.action_spec[(group, "action")].space.low,
+                        "max": self.action_spec[(group, "action")].space.high,
+                    }
+                    if self.use_tanh_normal
+                    else {}
+                ),
                 return_log_prob=True,
                 log_prob_key=(group, "log_prob"),
             )
@@ -224,14 +228,14 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
                 nested_reward_key,
                 batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
             )
-
-        with torch.no_grad():
-            loss = self.get_loss_and_updater(group)[0]
-            loss.value_estimator(
-                batch,
-                params=loss.critic_network_params,
-                target_params=loss.target_critic_network_params,
-            )
+        if self.pre_compute_advantage:
+            with torch.no_grad():
+                loss = self.get_loss_and_updater(group)[0]
+                loss.value_estimator(
+                    batch,
+                    params=loss.critic_network_params,
+                    target_params=loss.target_critic_network_params,
+                )

         return batch
@@ -321,6 +325,7 @@ class MappoConfig(AlgorithmConfig):
     lmbda: float = MISSING
     scale_mapping: str = MISSING
     use_tanh_normal: bool = MISSING
+    pre_compute_advantage: bool = MISSING

     @staticmethod
     def associated_class() -> Type[Algorithm]:
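The mappo.py change mirrors ippo.py line for line; the flag can likewise be set when wiring an experiment programmatically. A sketch assuming the standard BenchMARL experiment setup from the project README (the VMAS task and seed are illustrative):

from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

algorithm_config = MappoConfig.get_from_yaml()
algorithm_config.pre_compute_advantage = True  # new default is False

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=algorithm_config,
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)
experiment.run()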
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/ippo.yaml
@@ -11,3 +11,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+pre_compute_advantage: False
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/mappo.yaml
@@ -12,3 +12,4 @@ loss_critic_type: "l2"
 lmbda: 0.9
 scale_mapping: "biased_softplus_1.0"
 use_tanh_normal: True
+pre_compute_advantage: False
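Because these YAML files are the Hydra defaults behind IppoConfig and MappoConfig, the old behaviour should also be recoverable per run without editing them, via the usual Hydra override appended to the launch command, e.g. algorithm.pre_compute_advantage=True.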
