Skip to content

Commit

Permalink
[Feature] Improve GNN (#93)
Browse files Browse the repository at this point in the history
* gnn add edge features

* global pooling in gnns

* global pooling in gnns

* amend

* amend
  • Loading branch information
matteobettini authored Jun 11, 2024
1 parent 6c32503 commit d5f952f
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 87 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ agent group. Here is a table of the models implemented in BenchMARL
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | No | No |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |

And the ones that are _work in progress_
Expand Down
3 changes: 3 additions & 0 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ self_loops: False
gnn_class: torch_geometric.nn.conv.GraphConv
gnn_kwargs:
aggr: "add"

position_key: null
velocity_key: null
Loading

0 comments on commit d5f952f

Please sign in to comment.