diff --git a/benchmarl/algorithms/isac.py b/benchmarl/algorithms/isac.py index 29afc88b..eb6845ee 100644 --- a/benchmarl/algorithms/isac.py +++ b/benchmarl/algorithms/isac.py @@ -28,6 +28,7 @@ def __init__( loss_function: str, delay_qvalue: bool, target_entropy: Union[float, str], + discrete_target_entropy_weight: float, alpha_init: float, min_alpha: Optional[float], max_alpha: Optional[float], @@ -41,6 +42,7 @@ def __init__( self.num_qvalue_nets = num_qvalue_nets self.loss_function = loss_function self.target_entropy = target_entropy + self.discrete_target_entropy_weight = discrete_target_entropy_weight self.alpha_init = alpha_init self.min_alpha = min_alpha self.max_alpha = max_alpha @@ -86,9 +88,10 @@ def _get_loss( alpha_init=self.alpha_init, min_alpha=self.min_alpha, max_alpha=self.max_alpha, - action_space=self.action_spec[group, "action"], + action_space=self.action_spec, fixed_alpha=self.fixed_alpha, target_entropy=self.target_entropy, + target_entropy_weight=self.discrete_target_entropy_weight, delay_qvalue=self.delay_qvalue, num_actions=self.action_spec[group, "action"].space.n, ) @@ -351,6 +354,7 @@ class IsacConfig(AlgorithmConfig): loss_function: str = MISSING delay_qvalue: bool = MISSING target_entropy: Union[float, str] = MISSING + discrete_target_entropy_weight: float = MISSING alpha_init: float = MISSING min_alpha: Optional[float] = MISSING diff --git a/benchmarl/algorithms/masac.py b/benchmarl/algorithms/masac.py index b084fe53..99bb6480 100644 --- a/benchmarl/algorithms/masac.py +++ b/benchmarl/algorithms/masac.py @@ -22,6 +22,7 @@ def __init__( loss_function: str, delay_qvalue: bool, target_entropy: Union[float, str], + discrete_target_entropy_weight: float, alpha_init: float, min_alpha: Optional[float], max_alpha: Optional[float], @@ -35,6 +36,7 @@ def __init__( self.num_qvalue_nets = num_qvalue_nets self.loss_function = loss_function self.target_entropy = target_entropy + self.discrete_target_entropy_weight = discrete_target_entropy_weight self.alpha_init = alpha_init self.min_alpha = min_alpha self.max_alpha = max_alpha @@ -83,6 +85,7 @@ def _get_loss( action_space=self.action_spec, fixed_alpha=self.fixed_alpha, target_entropy=self.target_entropy, + target_entropy_weight=self.discrete_target_entropy_weight, delay_qvalue=self.delay_qvalue, num_actions=self.action_spec[group, "action"].space.n, ) @@ -426,6 +429,7 @@ class MasacConfig(AlgorithmConfig): loss_function: str = MISSING delay_qvalue: bool = MISSING target_entropy: Union[float, str] = MISSING + discrete_target_entropy_weight: float = MISSING alpha_init: float = MISSING min_alpha: Optional[float] = MISSING diff --git a/benchmarl/conf/algorithm/isac.yaml b/benchmarl/conf/algorithm/isac.yaml index 82c5d660..2cff5eb3 100644 --- a/benchmarl/conf/algorithm/isac.yaml +++ b/benchmarl/conf/algorithm/isac.yaml @@ -9,6 +9,7 @@ num_qvalue_nets: 2 loss_function: "l2" delay_qvalue: True target_entropy: "auto" +discrete_target_entropy_weight: 0.2 alpha_init: 1.0 min_alpha: null diff --git a/benchmarl/conf/algorithm/masac.yaml b/benchmarl/conf/algorithm/masac.yaml index 08d3d778..650fcb5b 100644 --- a/benchmarl/conf/algorithm/masac.yaml +++ b/benchmarl/conf/algorithm/masac.yaml @@ -10,6 +10,7 @@ num_qvalue_nets: 2 loss_function: "l2" delay_qvalue: True target_entropy: "auto" +discrete_target_entropy_weight: 0.2 alpha_init: 1.0 min_alpha: null diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index 3eb9736f..2db533be 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -60,7 +60,7 @@ off_policy_collected_frames_per_batch: 1000 # Number of environments used for collection # If the environment is vectorized, this will be the number of batched environments. # Otherwise batching will be simulated and each env will be run sequentially. -off_policy_n_envs_per_worker: 10 +off_policy_n_envs_per_worker: 1 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over. off_policy_n_optimizer_steps: 1000 # Number of frames used for each off_policy_n_optimizer_steps when training off-policy algorithms