Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Aug 2, 2024
1 parent 00c3e02 commit 345a5bd
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,8 @@ def test_share_params_between_models(
or (isinstance(model_name, list) and model_name[0] != "gnn")
):
pytest.skip("gnn model needs agent dim as input")
if (
packaging.version.parse(torchrl.__version__).local is None
and "gru" in model_name
):
pytest.skip("gru model needs torchrl from github")
torch.manual_seed(1)

torch.manual_seed(0)

input_spec, output_spec = _get_input_and_output_specs(
centralised=centralised,
Expand Down Expand Up @@ -308,8 +304,6 @@ def test_share_params_between_models(
agent_group="agents",
action_spec=None,
)
for param, second_param in zip(model.parameters(), second_model.parameters()):
assert not torch.eq(param, second_param).any()
model.share_params_with(second_model)
for param, second_param in zip(model.parameters(), second_model.parameters()):
assert torch.eq(param, second_param).all()
Expand Down

0 comments on commit 345a5bd

Please sign in to comment.