diff --git a/test/test_models.py b/test/test_models.py index 8cc75691..eb71fc89 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -258,12 +258,8 @@ def test_share_params_between_models( or (isinstance(model_name, list) and model_name[0] != "gnn") ): pytest.skip("gnn model needs agent dim as input") - if ( - packaging.version.parse(torchrl.__version__).local is None - and "gru" in model_name - ): - pytest.skip("gru model needs torchrl from github") - torch.manual_seed(1) + + torch.manual_seed(0) input_spec, output_spec = _get_input_and_output_specs( centralised=centralised, @@ -308,8 +304,6 @@ def test_share_params_between_models( agent_group="agents", action_spec=None, ) - for param, second_param in zip(model.parameters(), second_model.parameters()): - assert not torch.eq(param, second_param).any() model.share_params_with(second_model) for param, second_param in zip(model.parameters(), second_model.parameters()): assert torch.eq(param, second_param).all()