diff --git a/benchmarl/models/gru.py b/benchmarl/models/gru.py
index d54ae454..b8d97464 100644
--- a/benchmarl/models/gru.py
+++ b/benchmarl/models/gru.py
@@ -161,7 +161,7 @@ def forward(
         h_0=None,
     ):
         # Input and output always have the multiagent dimension
-        # Hidden state only has it when not centralised
+        # Hidden states always have it, except when centralised and sharing params
         # is_init never has it

         assert is_init is not None, "We need to pass is_init"
@@ -196,7 +196,7 @@ def forward(
             is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

         if h_0 is None:
-            if self.centralised:
+            if self.centralised and self.share_params:
                 shape = (
                     batch,
                     self.n_layers,
@@ -237,7 +237,7 @@ def run_net(self, input, is_init, h_0):
         if self.centralised:
             output, h_n = self.vmap_func_module(
                 self._empty_gru,
-                (0, None, None, None),
+                (0, None, None, -3),
                 (-2, -3),
             )(self.params, input, is_init, h_0)
         else:
diff --git a/benchmarl/models/lstm.py b/benchmarl/models/lstm.py
index 197c3e32..d92ebeda 100644
--- a/benchmarl/models/lstm.py
+++ b/benchmarl/models/lstm.py
@@ -162,7 +162,7 @@ def forward(
         c_0=None,
     ):
         # Input and output always have the multiagent dimension
-        # Hidden state only has it when not centralised
+        # Hidden states always have it, except when centralised and sharing params
         # is_init never has it

         assert is_init is not None, "We need to pass is_init"
@@ -199,7 +199,7 @@ def forward(
             is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

         if h_0 is None:
-            if self.centralised:
+            if self.centralised and self.share_params:
                 shape = (
                     batch,
                     self.n_layers,
@@ -242,7 +242,7 @@ def run_net(self, input, is_init, h_0, c_0):
         if self.centralised:
             output, h_n, c_n = self.vmap_func_module(
                 self._empty_lstm,
-                (0, None, None, None, None),
+                (0, None, None, -3, -3),
                 (-2, -3, -3),
             )(self.params, input, is_init, h_0, c_0)
         else:
diff --git a/test/test_models.py b/test/test_models.py
index 2545057f..8cc75691 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -188,6 +188,8 @@ def test_models_forward_shape(
         share_params=share_params,
         n_agents=n_agents,
     )
+    if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
+        pytest.skip("rnn model needs torchrl from github")

     if centralised:
         config.is_critic = True
@@ -282,6 +284,8 @@ def test_share_params_between_models(
     config = model_config_registry[model_name].get_from_yaml()
     if centralised:
         config.is_critic = True
+    if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
+        pytest.skip("rnn model needs torchrl from github")
    model = config.get_model(
         input_spec=input_spec,
         output_spec=output_spec,
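
Why the `in_dims` change in `gru.py`/`lstm.py`: once a centralised model stops sharing parameters, each agent keeps its own hidden state, so `h_0` carries an agent dimension at `-3` that the vmapped functional call must map over in lock-step with the stacked per-agent parameters (dim `0`), while `input` and `is_init` are still broadcast. Below is a minimal runnable sketch of that mapping, not benchmarl's actual module: the dimensions are toy values and `step` is a stand-in for the functional GRU (`is_init` is omitted).

```python
import torch
from torch.func import vmap

n_agents, batch, n_layers, hidden = 3, 4, 2, 8

# Stand-in for one agent's recurrent call (benchmarl vmaps a functional
# GRU over stacked per-agent parameters; a sum keeps the sketch runnable).
def step(params, x, h_0):
    return x.sum() + params.sum() + h_0.sum()

params = torch.randn(n_agents, hidden)                # stacked per-agent params, mapped over dim 0
x = torch.randn(batch, hidden)                        # joint input, broadcast (in_dim None)
h_0 = torch.zeros(batch, n_layers, n_agents, hidden)  # agent dim at -3, mapped over

# Mirrors the diff's (0, None, None, -3), minus the is_init slot:
out = vmap(step, in_dims=(0, None, -3))(params, x, h_0)
print(out.shape)  # torch.Size([3]) -- one result per agent
```

Before the patch, `h_0` was broadcast (`in_dim` `None`), which only holds in the centralised-and-shared case where the state has no agent dimension; hence the matching guard change to `if self.centralised and self.share_params` when building the default `h_0` shape.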
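
On the test changes: the skip keys off PEP 440 local version segments. The assumption is that a torchrl built from the GitHub sources carries one (e.g. `0.4.0+abc123`), while a PyPI release does not, so `.local is None` flags an installed release that predates the RNN fixes. A quick illustration with made-up version strings:

```python
from packaging import version

print(version.parse("0.4.0").local)         # None -> release install, skip RNN tests
print(version.parse("0.4.0+abc123").local)  # 'abc123' -> source build, run them
```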