Commit f24fe6c: add module__hidden_activation

bryanlimy committed Feb 12, 2024 (parent: 77940d9)

1 changed file, autoemulate/emulators/neural_networks/mlp.py, with 5 additions and 12 deletions.
--- a/autoemulate/emulators/neural_networks/mlp.py
+++ b/autoemulate/emulators/neural_networks/mlp.py
@@ -19,6 +19,7 @@ def __init__(
         output_size: int = None,
         random_state: int = None,
         hidden_sizes: Tuple[int] = (100,),
+        hidden_activation: Tuple[callable] = nn.ReLU,
     ):
         super(MLPModule, self).__init__(
             module_name="mlp",
@@ -29,7 +30,7 @@ def __init__(
         modules = []
         for hidden_size in hidden_sizes:
             modules.append(nn.Linear(in_features=input_size, out_features=hidden_size))
-            modules.append(nn.ReLU())
+            modules.append(hidden_activation())
             input_size = hidden_size
         modules.append(nn.Linear(in_features=input_size, out_features=output_size))
         self.model = nn.Sequential(*modules)
@@ -38,18 +39,10 @@ def get_grid_params(self, search_type: str = "random"):
         param_space = {
             "lr": loguniform(1e-06, 0.01),
             "max_epochs": np.arange(10, 110, 10).tolist(),
-            "module__hidden_sizes": [
-                (32,),
-                (64,),
-                (128,),
-                (128, 128),
-                (128, 256),
-                (256, 256),
-                (512, 512),
-                (128, 128, 128),
-            ],
             "batch_size": np.arange(2, 128, 2).tolist(),
-            "optimizer__weight_decay": loguniform(1e-8, 0.1),
+            "module__hidden_sizes": [(50,), (100,), (100, 100), (100, 200), (200, 200)],
+            "module__hidden_activation": [nn.ReLU, nn.Tanh, nn.Sigmoid, nn.GELU],
+            "optimizer__weight_decay": (1 / 10 ** np.arange(1, 10)).tolist(),
         }
         match search_type:
             case "random":
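For context, a minimal runnable sketch of the pattern this commit introduces. TinyMLP below is a hypothetical stand-in, not the repository's MLPModule: the real class inherits from a project-specific base (hence the super().__init__(module_name="mlp", ...) call in the diff) and takes extra arguments such as random_state, all omitted here. The layer-building loop mirrors the diff.

# Sketch only: TinyMLP is a simplified stand-in for MLPModule.
from typing import Tuple

import torch
from torch import nn

class TinyMLP(nn.Module):
    def __init__(
        self,
        input_size: int = 4,
        output_size: int = 1,
        hidden_sizes: Tuple[int, ...] = (100,),
        hidden_activation: type = nn.ReLU,  # an nn.Module class, not an instance
    ):
        super().__init__()
        modules = []
        for hidden_size in hidden_sizes:
            modules.append(nn.Linear(in_features=input_size, out_features=hidden_size))
            modules.append(hidden_activation())  # fresh activation instance per hidden layer
            input_size = hidden_size
        modules.append(nn.Linear(in_features=input_size, out_features=output_size))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)

net = TinyMLP(hidden_sizes=(100, 200), hidden_activation=nn.Tanh)
print(net.model)            # Linear -> Tanh -> Linear -> Tanh -> Linear
y = net(torch.randn(8, 4))  # output shape: (8, 1)

Two details worth noting. The grid entry "module__hidden_activation": [nn.ReLU, nn.Tanh, nn.Sigmoid, nn.GELU] lists activation classes rather than instances, so each hidden layer constructs its own fresh module; the Tuple[callable] annotation on the new argument is therefore slightly misleading, since a single class is passed (type or Callable[[], nn.Module] would describe it better). The module__ and optimizer__ prefixes follow skorch's double-underscore convention for routing search parameters to the wrapped module's constructor and to the optimizer, which the key names here suggest is the wrapper in use.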
