
Commit

update
bryanlimy committed Feb 26, 2024
1 parent d40d9c0 commit be2479c
Showing 1 changed file with 6 additions and 28 deletions.
34 changes: 6 additions & 28 deletions autoemulate/emulators/neural_net_torch.py
@@ -16,33 +16,6 @@
 from autoemulate.utils import set_random_seed
 
 
-class InputShapeSetter(Callback):
-    """Callback to set input and output layer sizes dynamically."""
-
-    def on_train_begin(
-        self,
-        net: NeuralNetRegressor,
-        X: torch.Tensor | np.ndarray = None,
-        y: torch.Tensor | np.ndarray = None,
-        **kwargs,
-    ):
-        if hasattr(net, "n_features_in_") and net.n_features_in_ != X.shape[-1]:
-            raise ValueError(
-                f"Mismatch number of features, "
-                f"expected {net.n_features_in_}, received {X.shape[-1]}."
-            )
-        # if hasattr(net, "n_features_in_") and net.n_features_in_ != X.shape[-1]:
-        #     raise ValueError(
-        #         f"Mismatch number of features, "
-        #         f"expected {net.n_features_in_}, received {X.shape[-1]}."
-        #     )
-        # if not hasattr(net, "n_features_in_"):
-        #     output_size = 1 if y.ndim == 1 else y.shape[1]
-        #     net.set_params(
-        #         module__input_size=X.shape[1], module__output_size=output_size
-        #     )
-
-
 class NeuralNetTorch(NeuralNetRegressor):
     """
     Wrap PyTorch modules in Skorch to make them compatible with scikit-learn.
@@ -63,7 +36,7 @@ def __init__(
         module__output_size: int = None,
         optimizer__weight_decay: float = 0.0,
         iterator_train__shuffle: bool = True,
-        callbacks: List[Callback] = [InputShapeSetter()],
+        callbacks: List[Callback] = None,
         train_split: bool = False,  # to run cross_validate without splitting the data
         verbose: int = 0,
         **kwargs,
@@ -168,6 +141,11 @@ def check_initialized(self, X: np.ndarray, y: np.ndarray):
 
     def fit_loop(self, X, y=None, epochs=None, **fit_params):
         X, y = self.check_data(X, y)
+        if hasattr(self, "n_features_in_") and self.n_features_in_ != X.shape[-1]:
+            raise ValueError(
+                f"Mismatch number of features, "
+                f"expected {self.n_features_in_}, received {X.shape[-1]}."
+            )
         return super().fit_loop(X, y, epochs, **fit_params)
 
     def partial_fit(self, X, y=None, classes=None, **fit_params):
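For context, here is a minimal standalone sketch of the guard this commit moves out of the `InputShapeSetter` callback and into `fit_loop`. The `Estimator` class below is a toy stand-in, not the autoemulate or skorch API; only the `n_features_in_` check mirrors the diff above. Once an estimator has recorded the number of features seen at fit time, refitting on data with a different number of columns raises the same `ValueError`.

    # Toy sketch of the feature-count guard now living in fit_loop
    # (hypothetical Estimator class, not autoemulate code).
    import numpy as np


    class Estimator:
        """Stand-in that mimics scikit-learn's n_features_in_ bookkeeping."""

        def fit(self, X: np.ndarray, y: np.ndarray) -> "Estimator":
            # Same guard as the new fit_loop: refuse to refit on a different width.
            if hasattr(self, "n_features_in_") and self.n_features_in_ != X.shape[-1]:
                raise ValueError(
                    f"Mismatch number of features, "
                    f"expected {self.n_features_in_}, received {X.shape[-1]}."
                )
            self.n_features_in_ = X.shape[-1]
            return self


    est = Estimator().fit(np.zeros((10, 5)), np.zeros(10))
    est.fit(np.zeros((10, 5)), np.zeros(10))      # fine: same number of features
    try:
        est.fit(np.zeros((10, 3)), np.zeros(10))  # raises: expected 5, received 3
    except ValueError as err:
        print(err)

Running the check inside `fit_loop` means it no longer depends on a default callback being present, which is presumably why `callbacks` can now default to `None`.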
