diff --git a/README.md b/README.md
index b7eaa3df..d6a5961c 100644
--- a/README.md
+++ b/README.md
@@ -263,7 +263,7 @@ agent group. Here is a table of the models implemented in BenchMARL
 | Name                           | Decentralized | Centralized with local inputs | Centralized with global input |
 |--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
 | [MLP](benchmarl/models/mlp.py) | Yes           | Yes                           | Yes                           |
-| [GNN](benchmarl/models/gnn.py) | Yes           | No                            | No                            |
+| [GNN](benchmarl/models/gnn.py) | Yes           | Yes                           | No                            |
 | [CNN](benchmarl/models/cnn.py) | Yes           | Yes                           | Yes                           |
 
 And the ones that are _work in progress_
diff --git a/benchmarl/conf/model/layers/gnn.yaml b/benchmarl/conf/model/layers/gnn.yaml
index e93a2fb4..fef327fd 100644
--- a/benchmarl/conf/model/layers/gnn.yaml
+++ b/benchmarl/conf/model/layers/gnn.yaml
@@ -6,3 +6,6 @@ self_loops: False
 gnn_class: torch_geometric.nn.conv.GraphConv
 gnn_kwargs:
   aggr: "add"
+
+position_key: null
+velocity_key: null
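For reference, a filled-in variant of this config might look as follows. This is a sketch, not part of the patch: the leaf names "pos" and "vel" are assumptions that must match keys in the task's observation spec, and GATv2Conv is chosen only because it can consume edge attributes (see the `edge_dim` check further down):

    self_loops: False

    gnn_class: torch_geometric.nn.conv.GATv2Conv
    gnn_kwargs:
      aggr: "add"

    position_key: "pos"
    velocity_key: "vel"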
diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py
index 76a0754e..97d7eb85 100644
--- a/benchmarl/models/gnn.py
+++ b/benchmarl/models/gnn.py
@@ -7,12 +7,15 @@
 from __future__ import annotations
 
 import importlib
+import inspect
+import warnings
 from dataclasses import dataclass, MISSING
 from math import prod
 from typing import Optional, Type
 
 import torch
 from tensordict import TensorDictBase
+from tensordict.utils import _unravel_key_to_tuple
 from torch import nn, Tensor
 
 from benchmarl.models.common import Model, ModelConfig
@@ -20,24 +23,29 @@
 _has_torch_geometric = importlib.util.find_spec("torch_geometric") is not None
 if _has_torch_geometric:
     import torch_geometric
+    from torch_geometric.transforms import BaseTransform
 
-TOPOLOGY_TYPES = {"full", "empty"}
+    class _RelVel(BaseTransform):
+        """Transform that reads graph.vel and writes node1.vel - node2.vel in the edge attributes"""
 
+        def __init__(self):
+            pass
 
-def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str):
-    if topology == "full":
-        adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)
-    elif topology == "empty":
-        adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)
+        def __call__(self, data):
+            (row, col), vel, pseudo = data.edge_index, data.vel, data.edge_attr
 
-    edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
+            cart = vel[row] - vel[col]
+            cart = cart.view(-1, 1) if cart.dim() == 1 else cart
 
-    if self_loops:
-        edge_index, _ = torch_geometric.utils.add_self_loops(edge_index)
-    else:
-        edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
+            if pseudo is not None:
+                pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
+                data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
+            else:
+                data.edge_attr = cart
+            return data
 
-    return edge_index
+
+TOPOLOGY_TYPES = {"full", "empty"}
 
 
 class Gnn(Model):
@@ -50,6 +58,14 @@ class Gnn(Model):
         self_loops (bool): Whether the resulting adjacency matrix will have self loops.
         gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
         gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
+        position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
+            representing the agent position. This key will not be processed as a node feature, but it will be used to construct
+            edge features. In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
+            one-dimensional distance for all neighbours in the graph.
+        velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
+            representing the agent velocity. This key will not be processed as a node feature, but it will be used to construct
+            edge features. In particular, it will be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all
+            neighbours in the graph.
 
     Examples:
 
@@ -87,8 +103,6 @@ class Gnn(Model):
             )
             experiment.run()
 
-
-
     """
 
     def __init__(
@@ -96,17 +110,42 @@ def __init__(
         self,
         topology: str,
         self_loops: bool,
         gnn_class: Type[torch_geometric.nn.MessagePassing],
-        gnn_kwargs: Optional[dict] = None,
+        gnn_kwargs: Optional[dict],
+        position_key: Optional[str],
+        velocity_key: Optional[str],
         **kwargs,
     ):
         self.topology = topology
         self.self_loops = self_loops
+        self.position_key = position_key
+        self.velocity_key = velocity_key
 
         super().__init__(**kwargs)
 
+        self.pos_features = sum(
+            [
+                spec.shape[-1]
+                for key, spec in self.input_spec.items(True, True)
+                if _unravel_key_to_tuple(key)[-1] == position_key
+            ]
+        )  # Input keys ending with `position_key`
+        if self.pos_features > 0:
+            self.pos_features += 1  # We will also add a 1-dimensional distance
+        self.vel_features = sum(
+            [
+                spec.shape[-1]
+                for key, spec in self.input_spec.items(True, True)
+                if _unravel_key_to_tuple(key)[-1] == velocity_key
+            ]
+        )  # Input keys ending with `velocity_key`
+        self.edge_features = self.pos_features + self.vel_features
         self.input_features = sum(
-            [spec.shape[-1] for spec in self.input_spec.values(True, True)]
-        )
+            [
+                spec.shape[-1]
+                for key, spec in self.input_spec.items(True, True)
+                if _unravel_key_to_tuple(key)[-1] not in (velocity_key, position_key)
+            ]
+        )  # Input keys not ending with `velocity_key` or `position_key`
         self.output_features = self.output_leaf_spec.shape[-1]
 
         if gnn_kwargs is None:
@@ -114,6 +153,21 @@ def __init__(
         gnn_kwargs.update(
             {"in_channels": self.input_features, "out_channels": self.output_features}
         )
+        self.gnn_supports_edge_attrs = (
+            "edge_dim" in inspect.getfullargspec(gnn_class).args
+        )
+        if (
+            self.position_key is not None or self.velocity_key is not None
+        ) and not self.gnn_supports_edge_attrs:
+            warnings.warn(
+                "Position key or velocity key provided but GNN class does not support edge attributes. "
+                "These input keys will be ignored. If instead you want to process them as node features, "
+                "just set them (position_key or velocity_key) to null."
+            )
+        if (
+            position_key is not None or velocity_key is not None
+        ) and self.gnn_supports_edge_attrs:
+            gnn_kwargs.update({"edge_dim": self.edge_features})
 
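The `edge_dim` introspection above decides whether edge attributes are forwarded to the convolution. A minimal sketch of what it evaluates to for two common torch_geometric convolutions (GATv2Conv declares an `edge_dim` argument, GraphConv does not, so the latter triggers the warning when position or velocity keys are set):

    import inspect

    from torch_geometric.nn import GATv2Conv, GraphConv

    # GATv2Conv can consume edge attributes:
    print("edge_dim" in inspect.getfullargspec(GATv2Conv).args)  # True
    # GraphConv cannot, so position/velocity keys would be ignored:
    print("edge_dim" in inspect.getfullargspec(GraphConv).args)  # False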
         self.gnns = nn.ModuleList(
             [
@@ -135,13 +189,13 @@ def _perform_checks(self):
             raise ValueError(
                 f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}"
             )
-        if self.centralised:
-            raise ValueError("GNN model can only be used in non-centralised critics")
+
         if not self.input_has_agent_dim:
             raise ValueError(
-                "The GNN module is not compatible with input that does not have the agent dimension,"
-                "such as the global state in centralised critics. Please choose another critic model"
-                "if your algorithm has a centralized critic and the task has a global state."
+                "The GNN module is not compatible with input that does not have the agent dimension, "
+                "such as the global state in centralised critics. Please choose another critic model "
+                "if your algorithm has a centralized critic and the task has a global state. "
+                "If you are using the GNN in a centralized critic, it should be the first layer."
             )
 
         input_shape = None
@@ -176,88 +230,125 @@ def _perform_checks(self):
 
     def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # Gather in_key
-        input = torch.cat([tensordict.get(in_key) for in_key in self.in_keys], dim=-1)
+        input = torch.cat(
+            [
+                tensordict.get(in_key)
+                for in_key in self.in_keys
+                if _unravel_key_to_tuple(in_key)[-1]
+                not in (self.position_key, self.velocity_key)
+            ],
+            dim=-1,
+        )
+        if self.position_key is not None:
+            pos = torch.cat(
+                [
+                    tensordict.get(in_key)
+                    for in_key in self.in_keys
+                    if _unravel_key_to_tuple(in_key)[-1] == self.position_key
+                ],
+                dim=-1,
+            )
+        else:
+            pos = None
+        if self.velocity_key is not None:
+            vel = torch.cat(
+                [
+                    tensordict.get(in_key)
+                    for in_key in self.in_keys
+                    if _unravel_key_to_tuple(in_key)[-1] == self.velocity_key
+                ],
+                dim=-1,
+            )
+        else:
+            vel = None
 
         batch_size = input.shape[:-2]
 
-        graph = batch_from_dense_to_ptg(x=input, edge_index=self.edge_index)
+        graph = _batch_from_dense_to_ptg(
+            x=input, edge_index=self.edge_index, pos=pos, vel=vel
+        )
+        forward_gnn_params = {
+            "x": graph.x,
+            "edge_index": graph.edge_index,
+        }
+        if (
+            self.position_key is not None or self.velocity_key is not None
+        ) and self.gnn_supports_edge_attrs:
+            forward_gnn_params.update({"edge_attr": graph.edge_attr})
 
         if not self.share_params:
-            res = torch.stack(
-                [
-                    gnn(graph.x, graph.edge_index).view(
-                        *batch_size,
-                        self.n_agents,
-                        self.output_features,
-                    )[..., i, :]
-                    for i, gnn in enumerate(self.gnns)
-                ],
-                dim=-2,
-            )
+            if not self.centralised:
+                res = torch.stack(
+                    [
+                        gnn(**forward_gnn_params).view(
+                            *batch_size,
+                            self.n_agents,
+                            self.output_features,
+                        )[..., i, :]
+                        for i, gnn in enumerate(self.gnns)
+                    ],
+                    dim=-2,
+                )
+            else:
+                res = torch.stack(
+                    [
+                        gnn(**forward_gnn_params)
+                        .view(
+                            *batch_size,
+                            self.n_agents,
+                            self.output_features,
+                        )
+                        .mean(dim=-2)  # Mean pooling
+                        for i, gnn in enumerate(self.gnns)
+                    ],
+                    dim=-2,
+                )
 
         else:
-            res = self.gnns[0](
-                graph.x,
-                graph.edge_index,
-            ).view(*batch_size, self.n_agents, self.output_features)
+            res = self.gnns[0](**forward_gnn_params).view(
+                *batch_size, self.n_agents, self.output_features
+            )
+            if self.centralised:
+                res = res.mean(dim=-2)  # Mean pooling
 
         tensordict.set(self.out_key, res)
         return tensordict
 
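For intuition about the topologies handled by the rewritten `_get_edge_index` below, here is a minimal sketch of the "full" topology for three agents with `self_loops=False` (the expected tensor is worked out by hand in the comment):

    import torch
    import torch_geometric

    adjacency = torch.ones(3, 3, dtype=torch.long)
    edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
    edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
    # All directed edges between distinct agents remain:
    # tensor([[0, 0, 1, 1, 2, 2],
    #         [1, 2, 0, 2, 0, 1]])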
-# class GnnKernel(nn.Module):
-#     def __init__(self, in_dim, out_dim, **cfg):
-#         super().__init__()
-#
-#         gnn_types = {"GraphConv", "GATv2Conv", "GINEConv"}
-#         aggr_types = {"add", "mean", "max"}
-#
-#         self.aggr = "add"
-#         self.gnn_type = "GraphConv"
-#
-#         self.in_dim = in_dim
-#         self.out_dim = out_dim
-#         self.activation_fn = nn.Tanh
-#
-#         if self.gnn_type == "GraphConv":
-#             self.gnn = GraphConv(
-#                 self.in_dim,
-#                 self.out_dim,
-#                 aggr=self.aggr,
-#             )
-#         elif self.gnn_type == "GATv2Conv":
-#             # Default adds self loops
-#             self.gnn = GATv2Conv(
-#                 self.in_dim,
-#                 self.out_dim,
-#                 edge_dim=self.edge_features,
-#                 fill_value=0.0,
-#                 share_weights=True,
-#                 add_self_loops=True,
-#                 aggr=self.aggr,
-#             )
-#         elif self.gnn_type == "GINEConv":
-#             self.gnn = GINEConv(
-#                 nn=nn.Sequential(
-#                     torch.nn.Linear(self.in_dim, self.out_dim),
-#                     self.activation_fn(),
-#                 ),
-#                 edge_dim=self.edge_features,
-#                 aggr=self.aggr,
-#             )
-#
-#     def forward(self, x, edge_index):
-#         out = self.gnn(x, edge_index)
-#         return out
+def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str):
+    if topology == "full":
+        adjacency = torch.ones(n_agents, n_agents, device=device, dtype=torch.long)
+        edge_index, _ = torch_geometric.utils.dense_to_sparse(adjacency)
+        if not self_loops:
+            edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
+    elif topology == "empty":
+        if self_loops:
+            edge_index = (
+                torch.arange(n_agents, device=device, dtype=torch.long)
+                .unsqueeze(0)
+                .repeat(2, 1)
+            )
+        else:
+            edge_index = torch.empty((2, 0), device=device, dtype=torch.long)
+    else:
+        raise ValueError(f"Topology {topology} not supported")
+
+    return edge_index
 
 
-def batch_from_dense_to_ptg(
+def _batch_from_dense_to_ptg(
     x: Tensor,
     edge_index: Tensor,
+    pos: Tensor = None,
+    vel: Tensor = None,
 ) -> torch_geometric.data.Batch:
     batch_size = prod(x.shape[:-2])
     n_agents = x.shape[-2]
     x = x.view(-1, x.shape[-1])
+    if pos is not None:
+        pos = pos.view(-1, pos.shape[-1])
+    if vel is not None:
+        vel = vel.view(-1, vel.shape[-1])
 
     b = torch.arange(batch_size, device=x.device)
 
@@ -265,6 +356,8 @@ def batch_from_dense_to_ptg(
     graphs.ptr = torch.arange(0, (batch_size + 1) * n_agents, n_agents)
     graphs.batch = torch.repeat_interleave(b, n_agents)
     graphs.x = x
+    graphs.pos = pos
+    graphs.vel = vel
     graphs.edge_attr = None
 
     n_edges = edge_index.shape[1]
@@ -278,6 +371,11 @@ def batch_from_dense_to_ptg(
     graphs.edge_index = batch_edge_index
 
     graphs = graphs.to(x.device)
+    if pos is not None:
+        graphs = torch_geometric.transforms.Cartesian(norm=False)(graphs)
+        graphs = torch_geometric.transforms.Distance(norm=False)(graphs)
+    if vel is not None:
+        graphs = _RelVel()(graphs)
 
     return graphs
 
@@ -292,6 +390,9 @@ class GnnConfig(ModelConfig):
     gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING
     gnn_kwargs: Optional[dict] = None
 
+    position_key: Optional[str] = None
+    velocity_key: Optional[str] = None
+
     @staticmethod
     def associated_class():
         return Gnn
diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst
index b195465e..50dfe47d 100644
--- a/docs/source/concepts/components.rst
+++ b/docs/source/concepts/components.rst
@@ -112,7 +112,7 @@ agent group. Here is a table of the models implemented in BenchMARL
    +=================================+===============+===============================+===============================+
    | :class:`~benchmarl.models.Mlp`  | Yes           | Yes                           | Yes                           |
    +---------------------------------+---------------+-------------------------------+-------------------------------+
-   | :class:`~benchmarl.models.Gnn`  | Yes           | No                            | No                            |
+   | :class:`~benchmarl.models.Gnn`  | Yes           | Yes                           | No                            |
    +---------------------------------+---------------+-------------------------------+-------------------------------+
    | :class:`~benchmarl.models.Cnn`  | Yes           | Yes                           | Yes                           |
    +---------------------------------+---------------+-------------------------------+-------------------------------+
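The table change above advertises the new centralised-with-local-inputs mode. A sketch of how one might use it, mirroring the pattern of the class docstring example (the MlpConfig values are illustrative): the GNN goes first in a sequence critic so it sees per-agent inputs, and its output is then mean-pooled over agents as implemented in `_forward` above:

    import torch_geometric
    from torch import nn

    from benchmarl.models import GnnConfig, MlpConfig
    from benchmarl.models.common import SequenceModelConfig

    # GNN first (it requires the agent dimension), MLP afterwards.
    critic_model_config = SequenceModelConfig(
        model_configs=[
            GnnConfig(
                topology="full",
                self_loops=False,
                gnn_class=torch_geometric.nn.conv.GraphConv,
                gnn_kwargs={"aggr": "add"},
            ),
            MlpConfig(
                num_cells=[64],
                layer_class=nn.Linear,
                activation_class=nn.Tanh,
            ),
        ],
        intermediate_sizes=[128],
    )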
diff --git a/test/test_models.py b/test/test_models.py
index 5b28e3b8..c7671f56 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -3,14 +3,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import contextlib
 from typing import List
 
 import pytest
 import torch
+import torch_geometric.nn
 
 from benchmarl.hydra_config import load_model_config_from_hydra
-from benchmarl.models import model_config_registry
+from benchmarl.models import GnnConfig, model_config_registry
 from benchmarl.models.common import output_has_agent_dim, SequenceModelConfig
 from hydra import compose, initialize
 
@@ -75,8 +76,11 @@ def test_models_forward_shape(
 ):
     if not input_has_agent_dim and not centralised:
         pytest.skip()  # this combination should never happen
-    if ("gnn" in model_name) and centralised:
-        pytest.skip("gnn model is always decentralized")
+    if ("gnn" in model_name) and (
+        not input_has_agent_dim
+        or (isinstance(model_name, list) and model_name[0] != "gnn")
+    ):
+        pytest.skip("gnn model needs agent dim as input")
 
     torch.manual_seed(0)
 
@@ -166,3 +170,104 @@ def test_models_forward_shape(
     input_td = input_spec.expand(batch_size).rand()
     out_td = model(input_td)
     assert output_spec.expand(batch_size).is_in(out_td)
+
+
+class TestGnn:
+    @pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)])
+    @pytest.mark.parametrize("share_params", [True, False])
+    @pytest.mark.parametrize("position_key", ["pos", None])
+    def test_gnn_edge_attrs(
+        self,
+        batch_size,
+        share_params,
+        position_key,
+        n_agents=3,
+        obs_size=4,
+        pos_size=2,
+        agent_group="agents",
+        out_features=5,
+    ):
+        torch.manual_seed(0)
+
+        multi_agent_obs = torch.rand((*batch_size, n_agents, obs_size))
+        multi_agent_pos = torch.rand((*batch_size, n_agents, pos_size))
+        input_spec = CompositeSpec(
+            {
+                agent_group: CompositeSpec(
+                    {
+                        "observation": UnboundedContinuousTensorSpec(
+                            shape=multi_agent_obs.shape[len(batch_size) :]
+                        ),
+                        "pos": UnboundedContinuousTensorSpec(
+                            shape=multi_agent_pos.shape[len(batch_size) :]
+                        ),
+                    },
+                    shape=(n_agents,),
+                )
+            }
+        )
+
+        output_spec = CompositeSpec(
+            {
+                agent_group: CompositeSpec(
+                    {
+                        "out": UnboundedContinuousTensorSpec(
+                            shape=(n_agents, out_features)
+                        )
+                    },
+                    shape=(n_agents,),
+                )
+            },
+        )
+
+        # Test with a GNN class that supports edge attributes
+        gnn = GnnConfig(
+            topology="full",
+            self_loops=True,
+            gnn_class=torch_geometric.nn.GATv2Conv,
+            gnn_kwargs=None,
+            position_key=position_key,
+        ).get_model(
+            input_spec=input_spec,
+            output_spec=output_spec,
+            agent_group=agent_group,
+            input_has_agent_dim=True,
+            n_agents=n_agents,
+            centralised=False,
+            share_params=share_params,
+            device="cpu",
+            action_spec=None,
+        )
+
+        obs_input = input_spec.expand(batch_size).rand()
+        output = gnn(obs_input)
+        assert output_spec.expand(batch_size).is_in(output)
+
+        # Test with a GNN without edge_attrs
+        with (
+            pytest.warns(
+                match="Position key or velocity key provided but GNN class does not support edge attributes*"
+            )
+            if position_key is not None
+            else contextlib.nullcontext()
+        ):
+            gnn = GnnConfig(
+                topology="full",
+                self_loops=True,
+                gnn_class=torch_geometric.nn.GraphConv,
+                gnn_kwargs=None,
+                position_key=position_key,
+            ).get_model(
+                input_spec=input_spec,
+                output_spec=output_spec,
+                agent_group=agent_group,
+                input_has_agent_dim=True,
+                n_agents=n_agents,
+                centralised=False,
+                share_params=share_params,
+                device="cpu",
+                action_spec=None,
+            )
+        obs_input = input_spec.expand(batch_size).rand()
+        output = gnn(obs_input)
+        assert output_spec.expand(batch_size).is_in(output)
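Finally, an end-to-end usage sketch of the new edge-feature path, under the assumption that the task's observation spec exposes per-agent leaf keys named "pos" and "vel" (as VMAS navigation-style tasks do):

    import torch_geometric

    from benchmarl.models import GnnConfig

    # Relative positions, 1-dimensional distances and relative velocities
    # become edge attributes; all remaining leaf keys stay node features.
    model_config = GnnConfig(
        topology="full",
        self_loops=False,
        gnn_class=torch_geometric.nn.conv.GATv2Conv,
        gnn_kwargs={},
        position_key="pos",
        velocity_key="vel",
    )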