Skip to content

Commit

Permalink
[Feature] Share parameters between models (#95)
Browse files Browse the repository at this point in the history
* share params models

* warn

* docs
  • Loading branch information
matteobettini authored Jun 12, 2024
1 parent f4d006e commit a5c629b
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 61 deletions.
25 changes: 24 additions & 1 deletion benchmarl/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#

import pathlib

import warnings
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence
Expand Down Expand Up @@ -157,6 +157,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# _check_spec(tensordict, self.output_spec)
return tensordict

def share_params_with(self, other_model):
"""Share paramters with another identical model model.
This function modifies in-place the parameters of ``other_model`` to reference the parameters of ``self``
Args:
other_model (Model): the model that will share the parameters of ``self``.
"""
if (
self.share_params != other_model.share_params
or self.centralised != other_model.centralised
or self.input_has_agent_dim != other_model.input_has_agent_dim
or self.input_spec != other_model.input_spec
or self.output_spec != other_model.output_spec
):
raise warnings.warn(
"Sharing parameters with models that are not identical. "
"This might result in unintended behavior or error."
)
for param, other_param in zip(self.parameters(), other_model.parameters()):
other_param.data[:] = param.data

###############################
# Abstract methods to implement
###############################
Expand Down
219 changes: 159 additions & 60 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,76 @@
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec


def _get_input_and_output_specs(
centralised,
input_has_agent_dim,
model_name,
share_params,
n_agents,
in_features=2,
out_features=4,
x=12,
y=12,
):

if model_name == "cnn":
multi_agent_input_shape = (n_agents, x, y, in_features)
single_agent_input_shape = (x, y, in_features)
else:
multi_agent_input_shape = (n_agents, in_features)
single_agent_input_shape = in_features

other_multi_agent_input_shape = (n_agents, in_features)
other_single_agent_input_shape = in_features

if input_has_agent_dim:
input_spec = CompositeSpec(
{
"agents": CompositeSpec(
{
"observation": UnboundedContinuousTensorSpec(
shape=multi_agent_input_shape
),
"other": UnboundedContinuousTensorSpec(
shape=other_multi_agent_input_shape
),
},
shape=(n_agents,),
)
}
)
else:
input_spec = CompositeSpec(
{
"observation": UnboundedContinuousTensorSpec(
shape=single_agent_input_shape
),
"other": UnboundedContinuousTensorSpec(
shape=other_single_agent_input_shape
),
},
)

if output_has_agent_dim(centralised=centralised, share_params=share_params):
output_spec = CompositeSpec(
{
"agents": CompositeSpec(
{
"out": UnboundedContinuousTensorSpec(
shape=(n_agents, out_features)
)
},
shape=(n_agents,),
)
},
)
else:
output_spec = CompositeSpec(
{"out": UnboundedContinuousTensorSpec(shape=(out_features,))},
)
return input_spec, output_spec


@pytest.mark.parametrize("model_name", model_config_registry.keys())
def test_loading_simple_models(model_name):
with initialize(version_base=None, config_path="../benchmarl/conf"):
Expand Down Expand Up @@ -72,7 +142,7 @@ def test_loading_sequence_models(model_name, intermediate_size=10):
],
)
def test_models_forward_shape(
share_params, centralised, input_has_agent_dim, model_name, batch_size
share_params, centralised, input_has_agent_dim, model_name, batch_size, n_agents=3
):
if not input_has_agent_dim and not centralised:
pytest.skip() # this combination should never happen
Expand All @@ -94,68 +164,84 @@ def test_models_forward_shape(
else:
config = model_config_registry[model_name].get_from_yaml()

n_agents = 2
x = 12
y = 12
channels = 3
out_features = 4
input_spec, output_spec = _get_input_and_output_specs(
centralised=centralised,
input_has_agent_dim=input_has_agent_dim,
model_name=model_name if isinstance(model_name, str) else model_name[0],
share_params=share_params,
n_agents=n_agents,
)

if "cnn" in model_name:
multi_agent_tensor = torch.rand((*batch_size, n_agents, x, y, channels))
single_agent_tensor = torch.rand((*batch_size, x, y, channels))
else:
multi_agent_tensor = torch.rand((*batch_size, n_agents, channels))
single_agent_tensor = torch.rand((*batch_size, channels))
model = config.get_model(
input_spec=input_spec,
output_spec=output_spec,
share_params=share_params,
centralised=centralised,
input_has_agent_dim=input_has_agent_dim,
n_agents=n_agents,
device="cpu",
agent_group="agents",
action_spec=None,
)
input_td = input_spec.expand(batch_size).rand()
out_td = model(input_td)
assert output_spec.expand(batch_size).is_in(out_td)

other_multi_agent_tensor = torch.rand((*batch_size, n_agents, channels))
other_single_agent_tensor = torch.rand((*batch_size, channels))

if input_has_agent_dim:
input_spec = CompositeSpec(
{
"agents": CompositeSpec(
{
"observation": UnboundedContinuousTensorSpec(
shape=multi_agent_tensor.shape[len(batch_size) :]
),
"other": UnboundedContinuousTensorSpec(
shape=other_multi_agent_tensor.shape[len(batch_size) :]
),
},
shape=(n_agents,),
)
}
)
else:
input_spec = CompositeSpec(
{
"observation": UnboundedContinuousTensorSpec(
shape=single_agent_tensor.shape[len(batch_size) :]
),
"other": UnboundedContinuousTensorSpec(
shape=other_single_agent_tensor.shape[len(batch_size) :]
),
},
)
@pytest.mark.parametrize("input_has_agent_dim", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize(
"model_name",
[
*model_config_registry.keys(),
["cnn", "gnn", "mlp"],
["cnn", "mlp", "gnn"],
["cnn", "mlp"],
],
)
@pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
def test_share_params_between_models(
share_params,
centralised,
input_has_agent_dim,
model_name,
batch_size,
n_agents=3,
):
if not input_has_agent_dim and not centralised:
pytest.skip() # this combination should never happen
if ("gnn" in model_name) and (
not input_has_agent_dim
or (isinstance(model_name, list) and model_name[0] != "gnn")
):
pytest.skip("gnn model needs agent dim as input")
torch.manual_seed(0)

if output_has_agent_dim(centralised=centralised, share_params=share_params):
output_spec = CompositeSpec(
{
"agents": CompositeSpec(
{
"out": UnboundedContinuousTensorSpec(
shape=(n_agents, out_features)
)
},
shape=(n_agents,),
)
},
input_spec, output_spec = _get_input_and_output_specs(
centralised=centralised,
input_has_agent_dim=input_has_agent_dim,
model_name=model_name if isinstance(model_name, str) else model_name[0],
share_params=share_params,
n_agents=n_agents,
)
input_spec2, output_spec2 = _get_input_and_output_specs(
centralised=centralised,
input_has_agent_dim=input_has_agent_dim,
model_name=model_name if isinstance(model_name, str) else model_name[0],
share_params=share_params,
n_agents=n_agents,
)

if isinstance(model_name, List):
config = SequenceModelConfig(
model_configs=[
model_config_registry[config].get_from_yaml() for config in model_name
],
intermediate_sizes=[4] * (len(model_name) - 1),
)
else:
output_spec = CompositeSpec(
{"out": UnboundedContinuousTensorSpec(shape=(out_features,))},
)

config = model_config_registry[model_name].get_from_yaml()
model = config.get_model(
input_spec=input_spec,
output_spec=output_spec,
Expand All @@ -167,9 +253,22 @@ def test_models_forward_shape(
agent_group="agents",
action_spec=None,
)
input_td = input_spec.expand(batch_size).rand()
out_td = model(input_td)
assert output_spec.expand(batch_size).is_in(out_td)
second_model = config.get_model(
input_spec=input_spec2,
output_spec=output_spec2,
share_params=share_params,
centralised=centralised,
input_has_agent_dim=input_has_agent_dim,
n_agents=n_agents,
device="cpu",
agent_group="agents",
action_spec=None,
)
for param, second_param in zip(model.parameters(), second_model.parameters()):
assert not torch.eq(param, second_param).any()
model.share_params_with(second_model)
for param, second_param in zip(model.parameters(), second_model.parameters()):
assert torch.eq(param, second_param).all()


class TestGnn:
Expand Down

0 comments on commit a5c629b

Please sign in to comment.