diff --git a/autoemulate/emulators/neural_net_torch.py b/autoemulate/emulators/neural_net_torch.py
index 3fa8ec30..a0691505 100644
--- a/autoemulate/emulators/neural_net_torch.py
+++ b/autoemulate/emulators/neural_net_torch.py
@@ -56,7 +56,7 @@ def __init__(
         max_epochs: int = 1,
         module__input_size: int = 2,
         module__output_size: int = 1,
-        optimizer__weight_decay: float = 0.0001,
+        optimizer__weight_decay: float = 0.0,
         iterator_train__shuffle: bool = True,
         callbacks: List[Callback] = [InputShapeSetter()],
         train_split: bool = False,  # to run cross_validate without splitting the data
diff --git a/autoemulate/emulators/neural_networks/mlp.py b/autoemulate/emulators/neural_networks/mlp.py
index dd7ba242..f7d94e1f 100644
--- a/autoemulate/emulators/neural_networks/mlp.py
+++ b/autoemulate/emulators/neural_networks/mlp.py
@@ -18,7 +18,8 @@ def __init__(
         input_size: int = None,
         output_size: int = None,
         random_state: int = None,
-        hidden_sizes: Tuple[int] = (100,),
+        hidden_layers: int = 1,
+        hidden_size: int = 100,
         hidden_activation: Tuple[callable] = nn.ReLU,
     ):
         super(MLPModule, self).__init__(
@@ -28,7 +29,8 @@ def __init__(
             random_state=random_state,
         )
         modules = []
-        for hidden_size in hidden_sizes:
+        assert hidden_layers >= 1
+        for _ in range(hidden_layers):
             modules.append(nn.Linear(in_features=input_size, out_features=hidden_size))
             modules.append(hidden_activation())
             input_size = hidden_size
@@ -37,18 +39,23 @@ def __init__(
 
     def get_grid_params(self, search_type: str = "random"):
         param_space = {
-            "lr": loguniform(1e-06, 0.01),
             "max_epochs": np.arange(10, 110, 10).tolist(),
             "batch_size": np.arange(2, 128, 2).tolist(),
-            "module__hidden_sizes": [(50,), (100,), (100, 100), (100, 200), (200, 200)],
-            "module__hidden_activation": [nn.ReLU, nn.Tanh, nn.Sigmoid, nn.GELU],
-            "optimizer__weight_decay": (1 / 10 ** np.arange(1, 10)).tolist(),
+            "module__hidden_layers": np.arange(1, 4).tolist(),
+            "module__hidden_size": np.arange(50, 250, 50).tolist(),
+            "module__hidden_activation": [
+                nn.ReLU,
+                nn.Tanh,
+                nn.Sigmoid,
+                nn.GELU,
+            ],
+            "optimizer__weight_decay": (1 / 10 ** np.arange(1, 9)).tolist(),
         }
 
         match search_type:
             case "random":
-                pass
+                param_space |= {"lr": loguniform(1e-06, 1e-2)}
             case "bayes":
-                pass
+                param_space |= {"lr": Real(1e-06, 1e-2, prior="log-uniform")}
             case _:
                 raise ValueError(f"Invalid search type: {search_type}")
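
For context on the reparameterization: MLPModule now builds hidden_layers identical blocks of width hidden_size instead of taking an arbitrary hidden_sizes tuple, which turns the architecture search into two integer parameters (depth, width). A minimal standalone sketch is below; the build_mlp helper and the trailing output layer are hypothetical illustrations written for this note, not code from this patch, and only the loop body is taken verbatim from the diff.

import torch.nn as nn

def build_mlp(input_size: int, output_size: int,
              hidden_layers: int = 1, hidden_size: int = 100,
              hidden_activation=nn.ReLU) -> nn.Sequential:
    """Hypothetical standalone version of the layer-stacking loop in MLPModule."""
    assert hidden_layers >= 1
    modules = []
    for _ in range(hidden_layers):
        # Every hidden layer has the same width, so the search space is
        # (depth, width) rather than a tuple-valued hyperparameter.
        modules.append(nn.Linear(in_features=input_size, out_features=hidden_size))
        modules.append(hidden_activation())
        input_size = hidden_size  # next layer consumes this layer's output
    # Assumed final projection to output_size (implied by the module's signature).
    modules.append(nn.Linear(in_features=input_size, out_features=output_size))
    return nn.Sequential(*modules)

# e.g. build_mlp(2, 1, hidden_layers=2, hidden_size=50)
# -> Linear(2, 50) -> ReLU -> Linear(50, 50) -> ReLU -> Linear(50, 1)

Note that get_grid_params now merges the learning-rate distribution per search type with dict |= (Python 3.9+) inside a match statement (Python 3.10+): scipy's loguniform for random search, skopt's Real(..., prior="log-uniform") for Bayesian search.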