Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Aug 2, 2024
1 parent a7d0627 commit 00c3e02
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
6 changes: 3 additions & 3 deletions benchmarl/models/gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def forward(
h_0=None,
):
# Input and output always have the multiagent dimension
# Hidden state only has it when not centralised
# Hidden states always have the multiagent dimension, except when centralised and share_params is set
# is_init never has it

assert is_init is not None, "We need to pass is_init"
Expand Down Expand Up @@ -196,7 +196,7 @@ def forward(
is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

if h_0 is None:
if self.centralised:
if self.centralised and self.share_params:
shape = (
batch,
self.n_layers,
Expand Down Expand Up @@ -237,7 +237,7 @@ def run_net(self, input, is_init, h_0):
if self.centralised:
output, h_n = self.vmap_func_module(
self._empty_gru,
(0, None, None, None),
(0, None, None, -3),
(-2, -3),
)(self.params, input, is_init, h_0)
else:
Expand Down
6 changes: 3 additions & 3 deletions benchmarl/models/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def forward(
c_0=None,
):
# Input and output always have the multiagent dimension
# Hidden state only has it when not centralised
# Hidden states always have the multiagent dimension, except when centralised and share_params is set
# is_init never has it

assert is_init is not None, "We need to pass is_init"
Expand Down Expand Up @@ -199,7 +199,7 @@ def forward(
is_init = is_init.unsqueeze(-2).expand(batch, seq, self.n_agents, 1)

if h_0 is None:
if self.centralised:
if self.centralised and self.share_params:
shape = (
batch,
self.n_layers,
Expand Down Expand Up @@ -242,7 +242,7 @@ def run_net(self, input, is_init, h_0, c_0):
if self.centralised:
output, h_n, c_n = self.vmap_func_module(
self._empty_lstm,
(0, None, None, None, None),
(0, None, None, -3, -3),
(-2, -3, -3),
)(self.params, input, is_init, h_0, c_0)
else:
Expand Down
4 changes: 4 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def test_models_forward_shape(
share_params=share_params,
n_agents=n_agents,
)
if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
pytest.skip("rnn model needs torchrl from github")

if centralised:
config.is_critic = True
Expand Down Expand Up @@ -282,6 +284,8 @@ def test_share_params_between_models(
config = model_config_registry[model_name].get_from_yaml()
if centralised:
config.is_critic = True
if packaging.version.parse(torchrl.__version__).local is None and config.is_rnn:
pytest.skip("rnn model needs torchrl from github")
model = config.get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand Down

0 comments on commit 00c3e02

Please sign in to comment.