Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Sep 3, 2024
1 parent 363b728 commit a4752b6
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions benchmarl/models/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,19 @@ class Gnn(Model):
self_loops (str): Whether the resulting adjacency matrix will have self loops.
gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent position. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent position. This key will be processed as a
node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features.
In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a
one-dimensional distance for all neighbours in the graph.
pos_features (int, optional): Needed when position_key is specified.
It has to match to the last element of the shape the tensor under position_key.
exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided,
wether to use it just to compute edge features or also include it in node features.
velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec
representing the agent velocity. This key will not be processed as a node feature, but it will used to construct
edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours
in the graph.
velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
(we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and
it will be used to construct edge features. In particular, it will be used to compute relative velocities
(``vel_node_1 - vel_node_2``) for all neighbours in the graph.
vel_features (int, optional): Needed when velocity_key is specified.
It has to match to the last element of the shape the tensor under velocity_key.
edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.
Expand Down

0 comments on commit a4752b6

Please sign in to comment.