Skip to content

Commit

Permalink
gnn radius
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 13, 2024
1 parent 8cc40a2 commit d0e766c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
3 changes: 3 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ gnn_kwargs:

position_key: null
velocity_key: null

exclude_pos_from_node_features: False
edge_radius: null
62 changes: 47 additions & 15 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __call__(self, data):
return data


TOPOLOGY_TYPES = {"full", "empty"}
TOPOLOGY_TYPES = {"full", "empty", "from_pos"}


class Gnn(Model):
Expand All @@ -54,18 +54,23 @@ class Gnn(Model):
GNN models can be used as "decentralized" actors or critics.
Args:
topology (str): Topology of the graph adjacency matrix. Options: "full", "empty".
topology (str): Topology of the graph adjacency matrix. Options: "full", "empty", "from_pos". "from_pos" builds
the topology dynamically based on ``position_key`` and ``edge_radius``.
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
exclude_pos_from_node_features (bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours
in the graph.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Agents within this radius distance will be neighnours.
Examples:
Expand Down Expand Up @@ -112,13 +117,17 @@ def __init__(
gnn_class: Type[torch_geometric.nn.MessagePassing],
gnn_kwargs: Optional[dict],
position_key: Optional[str],
exclude_pos_from_node_features: bool,
velocity_key: Optional[str],
edge_radius: Optional[float],
**kwargs,
):
self.topology = topology
self.self_loops = self_loops
self.position_key = position_key
self.velocity_key = velocity_key
self.exclude_pos_from_node_features = exclude_pos_from_node_features
self.edge_radius = edge_radius

super().__init__(**kwargs)

Expand All @@ -143,7 +152,8 @@ def __init__(
[
spec.shape[-1]
for key, spec in self.input_spec.items(True, True)
if _unravel_key_to_tuple(key)[-1] not in (velocity_key, position_key)
if _unravel_key_to_tuple(key)[-1]
not in ((position_key) if self.exclude_pos_from_node_features else ())
]
) # Input keys not ending with `velocity_key` and `position_key`
self.output_features = self.output_leaf_spec.shape[-1]
Expand Down Expand Up @@ -189,6 +199,8 @@ def _perform_checks(self):
raise ValueError(
f"Got topology: {self.topology} but only available options are {TOPOLOGY_TYPES}"
)
if self.topology == "from_pos" and self.position_key is None:
raise ValueError("If topology is from_pos, position_key must be provided")

if not self.input_has_agent_dim:
raise ValueError(
Expand Down Expand Up @@ -233,7 +245,9 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict.get(in_key)
for in_key in self.in_keys
if _unravel_key_to_tuple(in_key)[-1]
not in (self.position_key, self.velocity_key)
not in (
(self.position_key) if self.exclude_pos_from_node_features else ()
)
],
dim=-1,
)
Expand Down Expand Up @@ -263,7 +277,12 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
batch_size = input.shape[:-2]

graph = _batch_from_dense_to_ptg(
x=input, edge_index=self.edge_index, pos=pos, vel=vel
x=input,
edge_index=self.edge_index,
pos=pos,
vel=vel,
self_loops=self.self_loops,
edge_radius=self.edge_radius,
)
forward_gnn_params = {
"x": graph.x,
Expand Down Expand Up @@ -328,6 +347,8 @@ def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str)
)
else:
edge_index = torch.empty((2, 0), device=device, dtype=torch.long)
elif topology == "from_pos":
edge_index = None
else:
raise ValueError(f"Topology {topology} not supported")

Expand All @@ -336,9 +357,11 @@ def _get_edge_index(topology: str, self_loops: bool, n_agents: int, device: str)

def _batch_from_dense_to_ptg(
x: Tensor,
edge_index: Tensor,
edge_index: Optional[Tensor],
self_loops: bool,
pos: Tensor = None,
vel: Tensor = None,
edge_radius: Optional[float] = None,
) -> torch_geometric.data.Batch:
batch_size = prod(x.shape[:-2])
n_agents = x.shape[-2]
Expand All @@ -358,15 +381,22 @@ def _batch_from_dense_to_ptg(
graphs.vel = vel
graphs.edge_attr = None

n_edges = edge_index.shape[1]
# Tensor of shape [batch_size * n_edges]
# in which edges corresponding to the same graph have the same index.
batch = torch.repeat_interleave(b, n_edges)
# Edge index for the batched graphs of shape [2, n_edges * batch_size]
# we sum to each batch an offset of batch_num * n_agents to make sure that
# the adjacency matrices remain independent
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
graphs.edge_index = batch_edge_index
if edge_index is not None:
n_edges = edge_index.shape[1]
# Tensor of shape [batch_size * n_edges]
# in which edges corresponding to the same graph have the same index.
batch = torch.repeat_interleave(b, n_edges)
# Edge index for the batched graphs of shape [2, n_edges * batch_size]
# we sum to each batch an offset of batch_num * n_agents to make sure that
# the adjacency matrices remain independent
batch_edge_index = edge_index.repeat(1, batch_size) + batch * n_agents
graphs.edge_index = batch_edge_index
else:
if pos is None:
raise RuntimeError("from_pos topology needs positions as input")
graphs.edge_index = torch_geometric.nn.pool.radius_graph(
graphs.pos, batch=graphs.batch, r=edge_radius, loop=self_loops
)

graphs = graphs.to(x.device)
if pos is not None:
Expand All @@ -389,7 +419,9 @@ class GnnConfig(ModelConfig):
gnn_kwargs: Optional[dict] = None

position_key: Optional[str] = None
exclude_pos_from_node_features: bool = MISSING
velocity_key: Optional[str] = None
edge_radius: Optional[float] = None

@staticmethod
def associated_class():
Expand Down

0 comments on commit d0e766c

Please sign in to comment.