From a4752b60432b80ac9ee1117dcb9b01301e33e204 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 3 Sep 2024 19:04:26 +0200 Subject: [PATCH] amend --- benchmarl/models/gnn.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index a53ae19a..5cb0f00e 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -59,18 +59,19 @@ class Gnn(Model): self_loops (str): Whether the resulting adjacency matrix will have self loops. gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class - position_key (str, optional): if provided, it will need to match a leaf key in the env observation spec - representing the agent position. This key will not be processed as a node feature, but it will used to construct - edge features. In particular it be used to compute relative positions (``pos_node_1 - pos_node_2``) and a + position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env + (we suggest to use the "info" dict) representing the agent position. This key will be processed as a + node feature (unless exclude_pos_from_node_features=True) and it will be used to construct edge features. + In particular, it will be used to compute relative positions (``pos_node_1 - pos_node_2``) and a one-dimensional distance for all neighbours in the graph. pos_features (int, optional): Needed when position_key is specified. It has to match to the last element of the shape the tensor under position_key. exclude_pos_from_node_features (optional, bool): If ``position_key`` is provided, wether to use it just to compute edge features or also include it in node features. - velocity_key (str, optional): if provided, it will need to match a leaf key in the env observation spec - representing the agent velocity. This key will not be processed as a node feature, but it will used to construct - edge features. In particular it be used to compute relative velocities (``vel_node_1 - vel_node_2``) for all neighbours - in the graph. + velocity_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env + (we suggest to use the "info" dict) representing the agent velocity. This key will be processed as a node feature, and + it will be used to construct edge features. In particular, it will be used to compute relative velocities + (``vel_node_1 - vel_node_2``) for all neighbours in the graph. vel_features (int, optional): Needed when velocity_key is specified. It has to match to the last element of the shape the tensor under velocity_key. edge_radius (float, optional): If topology is ``"from_pos"`` the radius to use to build the agent graph.