[Conf] Discrete SAC entropy weight default
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 8, 2023
commit 1eaf349 (1 parent: 0b56e32)
Showing 5 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion benchmarl/algorithms/isac.py
@@ -28,6 +28,7 @@ def __init__(
         loss_function: str,
         delay_qvalue: bool,
         target_entropy: Union[float, str],
+        discrete_target_entropy_weight: float,
         alpha_init: float,
         min_alpha: Optional[float],
         max_alpha: Optional[float],
@@ -41,6 +42,7 @@ def __init__(
         self.num_qvalue_nets = num_qvalue_nets
         self.loss_function = loss_function
         self.target_entropy = target_entropy
+        self.discrete_target_entropy_weight = discrete_target_entropy_weight
         self.alpha_init = alpha_init
         self.min_alpha = min_alpha
         self.max_alpha = max_alpha
@@ -86,9 +88,10 @@ def _get_loss(
                 alpha_init=self.alpha_init,
                 min_alpha=self.min_alpha,
                 max_alpha=self.max_alpha,
-                action_space=self.action_spec[group, "action"],
+                action_space=self.action_spec,
                 fixed_alpha=self.fixed_alpha,
                 target_entropy=self.target_entropy,
+                target_entropy_weight=self.discrete_target_entropy_weight,
                 delay_qvalue=self.delay_qvalue,
                 num_actions=self.action_spec[group, "action"].space.n,
             )
@@ -351,6 +354,7 @@ class IsacConfig(AlgorithmConfig):
     loss_function: str = MISSING
     delay_qvalue: bool = MISSING
     target_entropy: Union[float, str] = MISSING
+    discrete_target_entropy_weight: float = MISSING
 
     alpha_init: float = MISSING
     min_alpha: Optional[float] = MISSING
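
Note: the new discrete_target_entropy_weight value is forwarded to the discrete SAC loss as target_entropy_weight. As far as I understand that convention, when target_entropy: "auto" is set the weight scales the entropy of a uniform categorical policy, so a smaller weight steers the temperature (alpha) update toward a less random policy. A minimal sketch of that assumed computation (the helper below is illustrative, not part of the repo):

import math


def auto_target_entropy(num_actions: int, target_entropy_weight: float) -> float:
    # Assumed "auto" rule: scale the maximum entropy of a uniform
    # categorical policy, log(num_actions), by the configured weight.
    return target_entropy_weight * math.log(num_actions)


# With 5 discrete actions, the new default weight of 0.2 targets
# roughly 0.32 nats, versus ~1.61 nats for a fully uniform policy.
print(auto_target_entropy(5, 0.2))  # ~0.322
print(auto_target_entropy(5, 1.0))  # ~1.609
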
4 changes: 4 additions & 0 deletions benchmarl/algorithms/masac.py
@@ -22,6 +22,7 @@ def __init__(
         loss_function: str,
         delay_qvalue: bool,
         target_entropy: Union[float, str],
+        discrete_target_entropy_weight: float,
         alpha_init: float,
         min_alpha: Optional[float],
         max_alpha: Optional[float],
@@ -35,6 +36,7 @@ def __init__(
         self.num_qvalue_nets = num_qvalue_nets
         self.loss_function = loss_function
         self.target_entropy = target_entropy
+        self.discrete_target_entropy_weight = discrete_target_entropy_weight
         self.alpha_init = alpha_init
         self.min_alpha = min_alpha
         self.max_alpha = max_alpha
@@ -83,6 +85,7 @@ def _get_loss(
                 action_space=self.action_spec,
                 fixed_alpha=self.fixed_alpha,
                 target_entropy=self.target_entropy,
+                target_entropy_weight=self.discrete_target_entropy_weight,
                 delay_qvalue=self.delay_qvalue,
                 num_actions=self.action_spec[group, "action"].space.n,
             )
@@ -426,6 +429,7 @@ class MasacConfig(AlgorithmConfig):
     loss_function: str = MISSING
     delay_qvalue: bool = MISSING
     target_entropy: Union[float, str] = MISSING
+    discrete_target_entropy_weight: float = MISSING
 
     alpha_init: float = MISSING
     min_alpha: Optional[float] = MISSING
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/isac.yaml
@@ -9,6 +9,7 @@ num_qvalue_nets: 2
 loss_function: "l2"
 delay_qvalue: True
 target_entropy: "auto"
+discrete_target_entropy_weight: 0.2
 
 alpha_init: 1.0
 min_alpha: null
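
With the default now exposed in the algorithm YAML, it can also be overridden from Python before building an experiment. A hedged sketch, assuming BenchMARL's get_from_yaml() loader for algorithm config dataclasses (verify against the repo's API before relying on it):

from benchmarl.algorithms import IsacConfig

# Load the defaults shipped in benchmarl/conf/algorithm/isac.yaml,
# including discrete_target_entropy_weight: 0.2, then override it.
algorithm_config = IsacConfig.get_from_yaml()
algorithm_config.discrete_target_entropy_weight = 0.5

print(algorithm_config.discrete_target_entropy_weight)  # 0.5
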
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/masac.yaml
@@ -10,6 +10,7 @@ num_qvalue_nets: 2
 loss_function: "l2"
 delay_qvalue: True
 target_entropy: "auto"
+discrete_target_entropy_weight: 0.2
 
 alpha_init: 1.0
 min_alpha: null
2 changes: 1 addition & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
@@ -60,7 +60,7 @@ off_policy_collected_frames_per_batch: 1000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
 # Otherwise batching will be simulated and each env will be run sequentially.
-off_policy_n_envs_per_worker: 10
+off_policy_n_envs_per_worker: 1
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
 # Number of frames used for each off_policy_n_optimizer_steps when training off-policy algorithms
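
The experiment default now collects with a single sequential environment per worker. Assuming off_policy_collected_frames_per_batch is the total frame budget per collection iteration shared across a worker's environments (my reading, not stated in this diff), the change redistributes the same 1000 frames rather than reducing them. A small sketch of that arithmetic:

# Assumed split: the per-iteration frame budget is divided evenly
# across the worker's (possibly simulated) batched environments.
collected_frames_per_batch = 1000

for n_envs_per_worker in (10, 1):  # previous default vs. new default
    frames_per_env = collected_frames_per_batch // n_envs_per_worker
    print(f"{n_envs_per_worker} env(s): {frames_per_env} frames each per iteration")
# 10 env(s): 100 frames each per iteration
# 1 env(s): 1000 frames each per iteration
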
