Skip to content

Commit

Permalink
[Model] DeepSets (#96)
Browse files Browse the repository at this point in the history
* amend

* docs

* nits

* nits

* nits

* gnn radius

* gnn radius

* amend

* amend

* amend

* amend

* amend

* docs

* docs
  • Loading branch information
matteobettini authored Jun 13, 2024
1 parent 70d9ec7 commit d5b0f51
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 29 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,12 @@ when requested, as critics. We provide a set of base models (layers) and a Seque
different layers. All the models can be used with or without parameter sharing within an
agent group. Here is a table of the models implemented in BenchMARL

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |

And the ones that are _work in progress_

Expand Down
9 changes: 9 additions & 0 deletions benchmarl/conf/model/layers/deepsets.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

name: deepsets

aggr: "sum"
local_nn_num_cells: [128, 128]
local_nn_activation_class: torch.nn.Tanh
out_features_local_nn: 256
global_nn_num_cells: [256, 256]
global_nn_activation_class: torch.nn.Tanh
19 changes: 17 additions & 2 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,24 @@

from .cnn import Cnn, CnnConfig
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"]
classes = [
"Mlp",
"MlpConfig",
"Gnn",
"GnnConfig",
"Cnn",
"CnnConfig",
"Deepsets",
"DeepsetsConfig",
]

model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig}
model_config_registry = {
"mlp": MlpConfig,
"gnn": GnnConfig,
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
}
Loading

0 comments on commit d5b0f51

Please sign in to comment.