Skip to content

Commit

Permalink
remove usage of global variable
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanlimy committed Feb 9, 2024
1 parent 66823d2 commit e84b172
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 48 deletions.
4 changes: 2 additions & 2 deletions autoemulate/emulators/neural_networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .mlp import MLPModule
from .neural_networks import get_module
from .base import TorchModule
from .get_module import get_module
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,6 @@

from autoemulate.utils import set_random_seed

_MODULES = dict()


def register(name):
def add_to_dict(fn):
global _MODULES
_MODULES[name] = fn
return fn

return add_to_dict


class TorchModule(nn.Module):
"""
Expand All @@ -40,16 +29,3 @@ def get_grid_params(self, search_type: str = "random"):

def forward(self, X: torch.Tensor):
raise NotImplementedError("forward method not implemented.")


def get_module(module: str | TorchModule, module_args) -> TorchModule:
"""
Return the module instance for NeuralNetRegressor. If `module` is a string,
then initialize a TorchModule with the same registered name. If `module` is
already a TorchModule, then return it as is.
"""
if isinstance(module, TorchModule):
return module
if module not in _MODULES:
raise NotImplementedError(f"Module {module} not implemented.")
return _MODULES[module](**module_args)
17 changes: 17 additions & 0 deletions autoemulate/emulators/neural_networks/get_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from autoemulate.emulators.neural_networks import TorchModule
from autoemulate.emulators.neural_networks.mlp import MLPModule


def get_module(module: str | TorchModule, module_args) -> TorchModule:
"""
Return the module instance for NeuralNetRegressor. If `module` is a string,
then initialize a TorchModule with the same registered name. If `module` is
already a TorchModule, then return it as is.
"""
if not isinstance(module, TorchModule):
match module:
case "mlp":
module = MLPModule(**module_args)
case _:
raise NotImplementedError(f"Module {module} not implemented.")
return module
40 changes: 18 additions & 22 deletions autoemulate/emulators/neural_networks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from skopt.space import Real
from torch import nn

from autoemulate.emulators.neural_networks.neural_networks import register
from autoemulate.emulators.neural_networks.neural_networks import TorchModule
from autoemulate.emulators.neural_networks.base import TorchModule


@register("mlp")
class MLPModule(TorchModule):
"""Multi-layer perceptron module for NeuralNetRegressor"""

def __init__(
self,
input_size: int = None,
Expand All @@ -34,28 +34,24 @@ def __init__(
self.model = nn.Sequential(*modules)

def get_grid_params(self, search_type: str = "random"):
param_space_random = {
"lr": loguniform(1e-4, 1e-2),
"max_epochs": [10, 20, 30],
"module__hidden_sizes": [
(50,),
(100,),
(100, 50),
(100, 100),
(200, 100),
],
}

param_space_bayes = {
"lr": Real(1e-4, 1e-2, prior="log-uniform"),
"max_epochs": Integer(10, 30),
}

match search_type:
case "random":
param_space = param_space_random
param_space = {
"lr": loguniform(1e-4, 1e-2),
"max_epochs": [10, 20, 30],
"module__hidden_sizes": [
(50,),
(100,),
(100, 50),
(100, 100),
(200, 100),
],
}
case "bayes":
param_space = param_space_bayes
param_space = {
"lr": Real(1e-4, 1e-2, prior="log-uniform"),
"max_epochs": Integer(10, 30),
}
case _:
raise ValueError(f"Invalid search type: {search_type}")

Expand Down

0 comments on commit e84b172

Please sign in to comment.