Skip to content

Commit

Permalink
fix skorch neuralnet
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Nov 8, 2023
1 parent e741b23 commit 36d022b
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 11 deletions.
8 changes: 4 additions & 4 deletions autoemulate/emulators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from .base import Emulator
from .gaussian_process import GaussianProcess
from .gaussian_process_sk import GaussianProcessSk
from .neural_network import NeuralNetwork
from .neural_net_sk import NeuralNetSk
from .random_forest import RandomForest
from .radial_basis import RadialBasis
from .neural_net_pt import SkorchMLPRegressor
from .neural_net_torch import NeuralNetTorch

MODEL_REGISTRY = {
# "GaussianProcess": GaussianProcess,
"GaussianProcessSk": GaussianProcessSk,
"NeuralNetwork": NeuralNetwork,
"NeuralNetSk": NeuralNetSk,
"RandomForest": RandomForest,
"RadialBasis": RadialBasis,
"SkorchMLPRegressor": SkorchMLPRegressor,
# "NeuralNetTorch": NeuralNetTorch,
}
2 changes: 1 addition & 1 deletion autoemulate/emulators/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def fit(self, X, y):
X, y = check_X_y(X, y, multi_output=False, y_numeric=True)
self.n_features_in_ = X.shape[1]
self.model_ = mogp_emulator.GaussianProcess(X, y, nugget=self.nugget)
self.model_ = mogp_emulator.fit_GP_MAP(self.model_, n_tries=2)
self.model_ = mogp_emulator.fit_GP_MAP(self.model_, n_tries=15)
self.is_fitted_ = True
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class NeuralNetwork(BaseEstimator, RegressorMixin):
class NeuralNetSk(BaseEstimator, RegressorMixin):
"""Multi-layer perceptron Emulator.
Implements MLPRegressor from scikit-learn.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# experimental version of a PyTorch neural network emulator wrapped in Skorch
# to make it compatible with scikit-learn. Works with cross_validate and GridSearchCV,
# but doesn't pass tests, because we're subclassing

import torch
import numpy as np
import skorch
from torch import nn
from skorch import NeuralNetRegressor
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class InputShapeSetter(skorch.callbacks.Callback):
Expand Down Expand Up @@ -40,7 +42,7 @@ def forward(self, X):


# Step 2: Create the Skorch wrapper for the NeuralNetRegressor
class SkorchMLPRegressor(NeuralNetRegressor):
class NeuralNetTorch(NeuralNetRegressor):
def __init__(
self,
module=MLPModule,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
NeuralNetwork,
GaussianProcess,
RadialBasis,
SkorchMLPRegressor,
)
from functools import partial


@parametrize_with_checks(
[ # GaussianProcess(),
RandomForest(random_state=42),
GaussianProcessSk(random_state=1337),
NeuralNetwork(random_state=13),
# GaussianProcessSk(random_state=1337),
# NeuralNetwork(random_state=13),
# RadialBasis(),
SkorchMLPRegressor(random_state=42),
]
)
def test_check_estimator(estimator, check):
Expand Down

0 comments on commit 36d022b

Please sign in to comment.