From d07331471506d7583ec0eda29e0d5533090ccf28 Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Mon, 14 Oct 2024 13:33:07 +0100
Subject: [PATCH 1/5] add attentive cnp

---
 autoemulate/emulators/__init__.py             |   6 +
 .../conditional_neural_process_attn.py        | 103 +++++++++
 .../neural_networks/attn_cnp_module.py        | 151 +++++++++++++
 tests/models/test_attn_cnp.py                 | 207 ++++++++++++++++++
 4 files changed, 467 insertions(+)
 create mode 100644 autoemulate/emulators/conditional_neural_process_attn.py
 create mode 100644 autoemulate/emulators/neural_networks/attn_cnp_module.py
 create mode 100644 tests/models/test_attn_cnp.py

diff --git a/autoemulate/emulators/__init__.py b/autoemulate/emulators/__init__.py
index a07abfca..a0fd0257 100644
--- a/autoemulate/emulators/__init__.py
+++ b/autoemulate/emulators/__init__.py
@@ -1,5 +1,6 @@
 from ..model_registry import ModelRegistry
 from .conditional_neural_process import ConditionalNeuralProcess
+from .conditional_neural_process_attn import AttentiveConditionalNeuralProcess
 from .gaussian_process import GaussianProcess
 from .gaussian_process_mogp import GaussianProcessMOGP
 from .gaussian_process_mt import GaussianProcessMT
@@ -38,6 +39,11 @@

 # non-core models
+model_registry.register_model(
+    AttentiveConditionalNeuralProcess().model_name,
+    AttentiveConditionalNeuralProcess,
+    is_core=True,
+)
 model_registry.register_model(
     GaussianProcessMT().model_name, GaussianProcessMT, is_core=False
 )

diff --git a/autoemulate/emulators/conditional_neural_process_attn.py b/autoemulate/emulators/conditional_neural_process_attn.py
new file mode 100644
index 00000000..a9bce866
--- /dev/null
+++ b/autoemulate/emulators/conditional_neural_process_attn.py
@@ -0,0 +1,103 @@
+import warnings
+
+import numpy as np
+import torch
+from scipy.stats import loguniform
+from sklearn.base import BaseEstimator
+from sklearn.base import RegressorMixin
+from sklearn.preprocessing._data import _handle_zeros_in_scale
+from sklearn.utils.validation import check_array
+from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import check_X_y
+from skorch import NeuralNetRegressor
+from skorch.callbacks import EarlyStopping
+from skorch.callbacks import GradientNormClipping
+from skorch.callbacks import LRScheduler
+from torch import nn
+
+from autoemulate.emulators.conditional_neural_process import ConditionalNeuralProcess
+from autoemulate.emulators.neural_networks.attn_cnp_module import AttnCNPModule
+from autoemulate.emulators.neural_networks.cnp_module import CNPModule
+from autoemulate.emulators.neural_networks.datasets import cnp_collate_fn
+from autoemulate.emulators.neural_networks.datasets import CNPDataset
+from autoemulate.emulators.neural_networks.losses import CNPLoss
+from autoemulate.utils import set_random_seed
+
+
+class AttentiveConditionalNeuralProcess(ConditionalNeuralProcess):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def fit(self, X, y):
+        X, y = check_X_y(
+            X,
+            y,
+            multi_output=True,
+            dtype=np.float32,
+            copy=True,
+            ensure_2d=True,
+            # ensure_min_samples=self.n_episode,
+            y_numeric=True,
+        )
+        # y also needs to be float32 and 2d
+        y = y.astype(np.float32)
+        self.y_dim_ = y.ndim
+        if len(y.shape) == 1:
+            y = y.reshape(-1, 1)
+
+        self.input_dim_ = X.shape[1]
+        self.output_dim_ = y.shape[1]
+
+        # Normalize target value
+        # the zero handler is from sklearn
+        if self.normalize_y:
+            self._y_train_mean = np.mean(y, axis=0)
+            self._y_train_std = _handle_zeros_in_scale(np.std(y, axis=0), copy=False)
+
+            # Remove mean and make unit variance
+            y = (y - self._y_train_mean) / self._y_train_std
+
+        if self.random_state is not None:
+            set_random_seed(self.random_state)
+
+        self.model_ = NeuralNetRegressor(
+            AttnCNPModule,
+            module__input_dim=self.input_dim_,
+            module__output_dim=self.output_dim_,
+            module__hidden_dim=self.hidden_dim,
+            module__latent_dim=self.latent_dim,
+            module__hidden_layers_enc=self.hidden_layers_enc,
+            module__hidden_layers_dec=self.hidden_layers_dec,
+            module__activation=self.activation,
+            dataset__min_context_points=self.min_context_points,
+            dataset__max_context_points=self.max_context_points,
+            dataset__n_episode=self.n_episode,
+            max_epochs=self.max_epochs,
+            lr=self.lr,
+            batch_size=self.batch_size,
+            optimizer=self.optimizer,
+            device=self.device,
+            dataset=CNPDataset,  # special dataset to sample context and target sets
+            criterion=CNPLoss,
+            iterator_train__collate_fn=cnp_collate_fn,  # special collate fn to handle varying n per episode
+            iterator_valid__collate_fn=cnp_collate_fn,
+            callbacks=[
+                ("early_stopping", EarlyStopping(patience=10)),
+                (
+                    "lr_scheduler",
+                    LRScheduler(policy="ReduceLROnPlateau", patience=5, factor=0.5),
+                ),
+                ("grad_norm", GradientNormClipping(gradient_clip_value=1.0)),
+            ],
+            # train_split=None,
+            verbose=0,
+        )
+        self.model_.fit(X, y)
+        self.X_train_ = X
+        self.y_train_ = y
+        self.n_features_in_ = X.shape[1]
+        return self
+
+    @property
+    def model_name(self):
+        return "AttentiveConditionalNeuralProcess"
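The CNPDataset / cnp_collate_fn pair wired in above trains on episodes: each batch element gets a randomly drawn number of context points, so the collate step has to pad the context sets to a common length and track which points are real. A shapes-only sketch of that contract (illustrative; not the actual autoemulate implementation):

    import torch

    # a collated batch of B episodes, contexts padded to the longest episode
    B, max_ctx, n_tgt, dx, dy = 16, 10, 5, 3, 2
    X_context = torch.zeros(B, max_ctx, dx)  # zero-padded beyond each episode's n
    y_context = torch.zeros(B, max_ctx, dy)
    X_target = torch.randn(B, n_tgt, dx)
    context_mask = torch.zeros(B, max_ctx, dtype=torch.bool)  # True = real context point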
diff --git a/autoemulate/emulators/neural_networks/attn_cnp_module.py b/autoemulate/emulators/neural_networks/attn_cnp_module.py
new file mode 100644
index 00000000..0b16cc74
--- /dev/null
+++ b/autoemulate/emulators/neural_networks/attn_cnp_module.py
@@ -0,0 +1,151 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.stats import loguniform
+from skopt.space import Categorical
+from skopt.space import Real
+
+
+class Encoder(nn.Module):
+    """
+    Deterministic encoder for the attentive conditional neural process model.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        hidden_dim,
+        latent_dim,
+        hidden_layers_enc,
+        activation,
+    ):
+        super().__init__()
+        layers = [nn.Linear(input_dim + output_dim, hidden_dim), activation()]
+        for _ in range(hidden_layers_enc):
+            layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
+        layers.append(nn.Linear(hidden_dim, latent_dim))
+        self.net = nn.Sequential(*layers)
+
+        self.x_encoder = nn.Linear(input_dim, latent_dim)
+
+        self.crossattn = nn.MultiheadAttention(
+            embed_dim=latent_dim, num_heads=4, batch_first=True
+        )
+
+    def forward(self, x_context, y_context, x_target, context_mask=None):
+        """
+        Encode the context set and cross-attend to it from the target inputs.
+
+        Parameters
+        ----------
+        x_context: (batch_size, n_context_points, input_dim)
+        y_context: (batch_size, n_context_points, output_dim)
+        x_target: (batch_size, n_target_points, input_dim)
+        context_mask: (batch_size, n_context_points), used as key_padding_mask
+            in the cross-attention (True = ignore that context point)
+
+        Returns
+        -------
+        r: (batch_size, n_target_points, latent_dim)
+        """
+        # encode each context (x, y) pair with an MLP
+        x = torch.cat([x_context, y_context], dim=-1)
+        r = self.net(x)
+        # cross-attention: encoded targets are queries, encoded context
+        # inputs are keys, context representations are values
+        x_target_enc = self.x_encoder(x_target)
+        x_context_enc = self.x_encoder(x_context)
+        if context_mask is not None:
+            r, _ = self.crossattn(
+                x_target_enc,
+                x_context_enc,
+                r,
+                need_weights=False,
+                key_padding_mask=context_mask,
+            )
+        else:
+            r, _ = self.crossattn(x_target_enc, x_context_enc, r, need_weights=False)
+        return r
+ """ + + def __init__( + self, + input_dim, + output_dim, + hidden_dim, + latent_dim, + hidden_layers_enc, + activation, + context_mask=None, + ): + super().__init__() + layers = [nn.Linear(input_dim + output_dim, hidden_dim), activation()] + for _ in range(hidden_layers_enc): + layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()]) + layers.append(nn.Linear(hidden_dim, latent_dim)) + self.net = nn.Sequential(*layers) + + self.x_encoder = nn.Linear(input_dim, latent_dim) + + self.crossattn = nn.MultiheadAttention( + embed_dim=latent_dim, num_heads=4, batch_first=True + ) + + def forward(self, x_context, y_context, x_target, context_mask=None): + """ + Encode context + + Parameters + ---------- + x_context: (batch_size, n_context_points, input_dim) + y_context: (batch_size, n_context_points, output_dim) + context_mask: (batch_size, n_context_points) + + Returns + ------- + r: (batch_size, n_points, latent_dim) + """ + # context self attention + x = torch.cat([x_context, y_context], dim=-1) + r = self.net(x) + # q, k, v + x_target_enc = self.x_encoder(x_target) + x_context_enc = self.x_encoder(x_context) + if context_mask is not None: + r, _ = self.crossattn( + x_target_enc, + x_context_enc, + r, + need_weights=False, + key_padding_mask=context_mask, + ) + else: + r, _ = self.crossattn(x_target_enc, x_context_enc, r, need_weights=False) + return r + + +class Decoder(nn.Module): + def __init__( + self, + input_dim, + latent_dim, + hidden_dim, + output_dim, + hidden_layers_dec, + activation, + ): + super().__init__() + layers = [nn.Linear(latent_dim + input_dim, hidden_dim), activation()] + for _ in range(hidden_layers_dec): + layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()]) + self.net = nn.Sequential(*layers) + self.mean_head = nn.Linear(hidden_dim, output_dim) + self.logvar_head = nn.Linear(hidden_dim, output_dim) + + def forward(self, r, x_target): + """ + Decode using representation r and target points x_target + + Parameters + ---------- + r: (batch_size, n_points, latent_dim) + x_target: (batch_size, n_points, input_dim) + + Returns + ------- + mean: (batch_size, n_points, output_dim) + logvar: (batch_size, n_points, output_dim) + """ + x = torch.cat([r, x_target], dim=-1) + hidden = self.net(x) + mean = self.mean_head(hidden) + logvar = self.logvar_head(hidden) + + return mean, logvar + + +class AttnCNPModule(nn.Module): + def __init__( + self, + input_dim, + output_dim, + hidden_dim, + latent_dim, + hidden_layers_enc, + hidden_layers_dec, + activation=nn.ReLU, + ): + super().__init__() + self.encoder = Encoder( + input_dim, output_dim, hidden_dim, latent_dim, hidden_layers_enc, activation + ) + self.decoder = Decoder( + input_dim, latent_dim, hidden_dim, output_dim, hidden_layers_dec, activation + ) + + def forward(self, X_context, y_context, X_target=None, context_mask=None): + """ + + Parameters + ---------- + X_context: (batch_size, n_context_points, input_dim) + y_context: (batch_size, n_context_points, output_dim) + X_target: (batch_size, n_target_points, input_dim) + context_mask: (batch_size, n_context_points), currently unused, + as we pad with 0's and don't have attention, layernorm yet. 
diff --git a/tests/models/test_attn_cnp.py b/tests/models/test_attn_cnp.py
new file mode 100644
index 00000000..2a76335e
--- /dev/null
+++ b/tests/models/test_attn_cnp.py
@@ -0,0 +1,207 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from autoemulate.emulators.neural_networks.attn_cnp_module import AttnCNPModule
+from autoemulate.emulators.neural_networks.attn_cnp_module import Decoder
+from autoemulate.emulators.neural_networks.attn_cnp_module import Encoder
+
+
+# encoder ----------------------------
+@pytest.fixture
+def encoder():
+    input_dim = 3
+    output_dim = 2
+    hidden_dim = 64
+    latent_dim = 32
+    hidden_layers_enc = 3
+    activation = nn.ReLU
+    return Encoder(
+        input_dim, output_dim, hidden_dim, latent_dim, hidden_layers_enc, activation
+    )
+
+
+def test_encoder_initialization(encoder):
+    assert isinstance(encoder, nn.Module)
+    assert isinstance(encoder.net, nn.Sequential)
+
+
+def test_encoder_forward_shape(encoder):
+    batch_size = 10
+    n_points = 5
+    n_target = 3
+    input_dim = 3
+    output_dim = 2
+
+    x_context = torch.randn(batch_size, n_points, input_dim)
+    y_context = torch.randn(batch_size, n_points, output_dim)
+    x_target = torch.randn(batch_size, n_target, input_dim)
+
+    r = encoder(x_context, y_context, x_target)
+
+    assert r.shape == (batch_size, n_target, 32)
+
+
+def test_encoder_forward_deterministic(encoder):
+    batch_size = 10
+    n_points = 5
+    input_dim = 3
+    output_dim = 2
+    n_target = 3
+
+    x_context = torch.randn(batch_size, n_points, input_dim)
+    y_context = torch.randn(batch_size, n_points, output_dim)
+    x_target = torch.randn(batch_size, n_target, input_dim)
+
+    r1 = encoder(x_context, y_context, x_target)
+    r2 = encoder(x_context, y_context, x_target)
+
+    assert torch.allclose(r1, r2)
+
+
+def test_encoder_different_batch_sizes(encoder):
+    input_dim = 3
+    output_dim = 2
+
+    batch_sizes = [1, 5, 10]
+    n_points = 5
+    n_target = 3
+
+    for batch_size in batch_sizes:
+        x_context = torch.randn(batch_size, n_points, input_dim)
+        y_context = torch.randn(batch_size, n_points, output_dim)
+        x_target = torch.randn(batch_size, n_target, input_dim)
+
+        r = encoder(x_context, y_context, x_target)
+        assert r.shape == (batch_size, n_target, 32)
+
+
+def test_encoder_mask(encoder):
+    input_dim = 5
+    output_dim = 3
+    hidden_dim = 64
+    latent_dim = 32
+    hidden_layers_enc = 2
+    activation = torch.nn.ReLU
+
+    encoder = Encoder(
+        input_dim, output_dim, hidden_dim, latent_dim, hidden_layers_enc, activation
+    )
+
+    batch_size = 2
+    n_context_points = 10
+    n_target_points = 8
+
+    x_context = torch.randn(batch_size, n_context_points, input_dim)
+    y_context = torch.randn(batch_size, n_context_points, output_dim)
+    x_target = torch.randn(batch_size, n_target_points, input_dim)
+
+    context_mask = torch.ones(batch_size, n_context_points, dtype=torch.bool)
+    context_mask[:, -3:] = False
+
+    output_with_mask = encoder(x_context, y_context, x_target, context_mask)
+    output_without_mask = encoder(x_context, y_context, x_target)
+
+    assert not torch.allclose(
+        output_with_mask, output_without_mask
+    ), "Mask doesn't seem to affect the output"
+
+    assert output_with_mask.shape == (batch_size, n_target_points, latent_dim)
+    assert output_without_mask.shape == (batch_size, n_target_points, latent_dim)
+
+
+# decoder ----------------------------
+@pytest.fixture
+def decoder():
+    input_dim = 2
+    latent_dim = 64
+    hidden_dim = 128
+    output_dim = 1
+    hidden_layers_dec = 5
+    activation = nn.ReLU
+    return Decoder(
+        input_dim, latent_dim, hidden_dim, output_dim, hidden_layers_dec, activation
+    )
+
+
+def test_decoder_initialization(decoder):
+    assert isinstance(decoder, nn.Module)
+    assert isinstance(decoder.net, nn.Sequential)
+    assert isinstance(decoder.mean_head, nn.Linear)
+    assert isinstance(decoder.logvar_head, nn.Linear)
+
+
+def test_decoder_forward_shape(decoder):
+    batch_size, n_points, input_dim = 10, 5, 2
+    latent_dim = 64
+
+    r = torch.randn(batch_size, n_points, latent_dim)
+    x_target = torch.randn(batch_size, n_points, input_dim)
+
+    mean, logvar = decoder(r, x_target)
+
+    assert mean.shape == (batch_size, n_points, 1)
+    assert logvar.shape == (batch_size, n_points, 1)
+
+
+def test_decoder_different_batch_sizes(decoder):
+    latent_dim = 64
+    input_dim = 2
+
+    for batch_size in [1, 10, 100]:
+        for n_points in [1, 5, 20]:
+            r = torch.randn(batch_size, n_points, latent_dim)
+            x_target = torch.randn(batch_size, n_points, input_dim)
+
+            mean, logvar = decoder(r, x_target)
+
+            assert mean.shape == (batch_size, n_points, 1)
+            assert logvar.shape == (batch_size, n_points, 1)
+
+
+# attn cnp ----------------------------
+@pytest.fixture
+def attn_cnp_module():
+    return AttnCNPModule(
+        input_dim=2,
+        output_dim=1,
+        hidden_dim=32,
+        latent_dim=64,
+        hidden_layers_enc=2,
+        hidden_layers_dec=2,
+        activation=nn.ReLU,
+    )
+
+
+def test_attn_cnp_module_initialization(attn_cnp_module):
+    assert isinstance(attn_cnp_module, AttnCNPModule)
+    assert isinstance(attn_cnp_module.encoder, Encoder)
+    assert isinstance(attn_cnp_module.decoder, Decoder)
+
+
+def test_attn_cnp_module_forward_shape(attn_cnp_module):
+    n_points = 16
+    b, n, dx = 32, n_points, 2
+    dy = 1
+    X_context = torch.randn(b, n_points, dx)
+    y_context = torch.randn(b, n_points, dy)
+    X_target = torch.randn(b, n, dx)
+
+    mean, logvar = attn_cnp_module(X_context, y_context, X_target)
+
+    assert mean.shape == (b, n, dy)
+    assert logvar.shape == (b, n, dy)
+
+
+def test_attn_cnp_module_forward_shape_2d(attn_cnp_module):
+    b, n, dx = 32, 24, 2
+    dy = 2
+    X = torch.randn(b, n, dx)
+    y = torch.randn(b, n, dy)
+    # re-initialise with 2 output dims
+    attn_cnp_module = AttnCNPModule(
+        input_dim=2, output_dim=2, hidden_dim=32, latent_dim=64, n_context_points=16
+    )
+    mean, logvar = attn_cnp_module(X, y)
+    assert mean.shape == (b, n, dy)
+    assert logvar.shape == (b, n, dy)
From 17b88260ae25118173f986693e78a4fdb6c3934e Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Tue, 15 Oct 2024 10:10:18 +0100
Subject: [PATCH 2/5] fix tests

---
 tests/models/test_attn_cnp.py | 37 +++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/tests/models/test_attn_cnp.py b/tests/models/test_attn_cnp.py
index 2a76335e..07efaa75 100644
--- a/tests/models/test_attn_cnp.py
+++ b/tests/models/test_attn_cnp.py
@@ -2,9 +2,9 @@
 import torch
 import torch.nn as nn

-from autoemulate.emulators.neural_networks.attn_cnp_module import AttnCNPModule
-from autoemulate.emulators.neural_networks.attn_cnp_module import Decoder
-from autoemulate.emulators.neural_networks.attn_cnp_module import Encoder
+from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule
+from autoemulate.emulators.neural_networks.cnp_module_attn import Decoder
+from autoemulate.emulators.neural_networks.cnp_module_attn import Encoder


 # encoder ----------------------------
@@ -173,6 +173,19 @@ def attn_cnp_module():
     )


+@pytest.fixture
+def attn_cnp_module_2d():
+    return AttnCNPModule(
+        input_dim=2,
+        output_dim=2,
+        hidden_dim=32,
+        latent_dim=64,
+        hidden_layers_enc=2,
+        hidden_layers_dec=2,
+        activation=nn.ReLU,
+    )
+
+
 def test_attn_cnp_module_initialization(attn_cnp_module):
     assert isinstance(attn_cnp_module, AttnCNPModule)
     assert isinstance(attn_cnp_module.encoder, Encoder)
@@ -193,15 +206,15 @@ def test_attn_cnp_module_forward_shape(attn_cnp_module):
     assert logvar.shape == (b, n, dy)


-def test_attn_cnp_module_forward_shape_2d(attn_cnp_module):
-    b, n, dx = 32, 24, 2
+def test_attn_cnp_module_forward_shape_2d(attn_cnp_module_2d):
+    n_points = 16
+    b, n, dx = 32, n_points, 2
     dy = 2
-    X = torch.randn(b, n, dx)
-    y = torch.randn(b, n, dy)
-    # re-initialise with 2 output dims
-    attn_cnp_module = AttnCNPModule(
-        input_dim=2, output_dim=2, hidden_dim=32, latent_dim=64, n_context_points=16
-    )
-    mean, logvar = attn_cnp_module(X, y)
+    X_context = torch.randn(b, n_points, dx)
+    y_context = torch.randn(b, n_points, dy)
+    X_target = torch.randn(b, n, dx)
+
+    mean, logvar = attn_cnp_module_2d(X_context, y_context, X_target)
+
     assert mean.shape == (b, n, dy)
     assert logvar.shape == (b, n, dy)
From 53507a9f6f1f6ada19e655e2a4de55fe53cce7f2 Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Tue, 15 Oct 2024 11:13:29 +0100
Subject: [PATCH 3/5] refactor attn cnp

---
 autoemulate/emulators/__init__.py             |   2 +-
 .../emulators/conditional_neural_process.py   |   7 +-
 .../conditional_neural_process_attn.py        | 134 ++++++------------
 ...{attn_cnp_module.py => cnp_module_attn.py} |   0
 tests/models/test_attn_cnp.py                 |  24 ++++
 tests/test_estimators.py                      |   4 +
 6 files changed, 73 insertions(+), 98 deletions(-)
 rename autoemulate/emulators/neural_networks/{attn_cnp_module.py => cnp_module_attn.py} (100%)

diff --git a/autoemulate/emulators/__init__.py b/autoemulate/emulators/__init__.py
index a0fd0257..a3d83d41 100644
--- a/autoemulate/emulators/__init__.py
+++ b/autoemulate/emulators/__init__.py
@@ -42,7 +42,7 @@
 model_registry.register_model(
     AttentiveConditionalNeuralProcess().model_name,
     AttentiveConditionalNeuralProcess,
-    is_core=True,
+    is_core=False,
 )
 model_registry.register_model(
     GaussianProcessMT().model_name, GaussianProcessMT, is_core=False
 )
diff --git a/autoemulate/emulators/conditional_neural_process.py b/autoemulate/emulators/conditional_neural_process.py
index 2f7d72d8..771c4add 100644
--- a/autoemulate/emulators/conditional_neural_process.py
+++ b/autoemulate/emulators/conditional_neural_process.py
@@ -16,6 +16,7 @@
 from torch import nn

 from autoemulate.emulators.neural_networks.cnp_module import CNPModule
+from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule
 from autoemulate.emulators.neural_networks.datasets import cnp_collate_fn
 from autoemulate.emulators.neural_networks.datasets import CNPDataset
 from autoemulate.emulators.neural_networks.losses import CNPLoss
@@ -140,9 +141,6 @@ def __init__(
         self.activation = activation
         self.optimizer = optimizer
         self.normalize_y = normalize_y
-        if attention:
-            warnings.warn("Attention is not implemented yet, setting to False.")
-            attention = False
         self.attention = attention
         self.device = device
         self.random_state = random_state
@@ -181,8 +179,9 @@ def fit(self, X, y):
         if self.random_state is not None:
             set_random_seed(self.random_state)

+        module = CNPModule if not self.attention else AttnCNPModule
         self.model_ = NeuralNetRegressor(
-            CNPModule,
+            module,
             module__input_dim=self.input_dim_,
             module__output_dim=self.output_dim_,
             module__hidden_dim=self.hidden_dim,
diff --git a/autoemulate/emulators/conditional_neural_process_attn.py b/autoemulate/emulators/conditional_neural_process_attn.py
index a9bce866..059a5b1e 100644
--- a/autoemulate/emulators/conditional_neural_process_attn.py
+++ b/autoemulate/emulators/conditional_neural_process_attn.py
@@ -1,103 +1,51 @@
-import warnings
-
-import numpy as np
 import torch
-from scipy.stats import loguniform
+import torch.nn as nn
 from sklearn.base import BaseEstimator
 from sklearn.base import RegressorMixin
-from sklearn.preprocessing._data import _handle_zeros_in_scale
-from sklearn.utils.validation import check_array
-from sklearn.utils.validation import check_is_fitted
-from sklearn.utils.validation import check_X_y
-from skorch import NeuralNetRegressor
-from skorch.callbacks import EarlyStopping
-from skorch.callbacks import GradientNormClipping
-from skorch.callbacks import LRScheduler
-from torch import nn

 from autoemulate.emulators.conditional_neural_process import ConditionalNeuralProcess
-from autoemulate.emulators.neural_networks.attn_cnp_module import AttnCNPModule
-from autoemulate.emulators.neural_networks.cnp_module import CNPModule
-from autoemulate.emulators.neural_networks.datasets import cnp_collate_fn
-from autoemulate.emulators.neural_networks.datasets import CNPDataset
-from autoemulate.emulators.neural_networks.losses import CNPLoss
 from autoemulate.utils import set_random_seed


 class AttentiveConditionalNeuralProcess(ConditionalNeuralProcess):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def fit(self, X, y):
-        X, y = check_X_y(
-            X,
-            y,
-            multi_output=True,
-            dtype=np.float32,
-            copy=True,
-            ensure_2d=True,
-            # ensure_min_samples=self.n_episode,
-            y_numeric=True,
+    def __init__(
+        self,
+        # architecture
+        hidden_dim=64,
+        latent_dim=64,
+        hidden_layers_enc=3,
+        hidden_layers_dec=3,
+        # data per episode
+        min_context_points=3,
+        max_context_points=10,
+        n_episode=32,
+        # training
+        max_epochs=100,
+        lr=5e-3,
+        batch_size=16,
+        activation=nn.ReLU,
+        optimizer=torch.optim.AdamW,
+        normalize_y=True,
+        # misc
+        device="cpu",
+        random_state=None,
+        attention=True,
+    ):
+        super().__init__(
+            hidden_dim=hidden_dim,
+            latent_dim=latent_dim,
+            hidden_layers_enc=hidden_layers_enc,
+            hidden_layers_dec=hidden_layers_dec,
+            min_context_points=min_context_points,
+            max_context_points=max_context_points,
+            n_episode=n_episode,
+            max_epochs=max_epochs,
+            lr=lr,
+            batch_size=batch_size,
+            activation=activation,
+            optimizer=optimizer,
+            normalize_y=normalize_y,
+            device=device,
+            random_state=random_state,
+            attention=attention,
         )
-        # y also needs to be float32 and 2d
-        y = y.astype(np.float32)
-        self.y_dim_ = y.ndim
-        if len(y.shape) == 1:
-            y = y.reshape(-1, 1)
-
-        self.input_dim_ = X.shape[1]
-        self.output_dim_ = y.shape[1]
-
-        # Normalize target value
-        # the zero handler is from sklearn
-        if self.normalize_y:
-            self._y_train_mean = np.mean(y, axis=0)
-            self._y_train_std = _handle_zeros_in_scale(np.std(y, axis=0), copy=False)
-
-            # Remove mean and make unit variance
-            y = (y - self._y_train_mean) / self._y_train_std
-
-        if self.random_state is not None:
-            set_random_seed(self.random_state)
-
-        self.model_ = NeuralNetRegressor(
-            AttnCNPModule,
-            module__input_dim=self.input_dim_,
-            module__output_dim=self.output_dim_,
-            module__hidden_dim=self.hidden_dim,
-            module__latent_dim=self.latent_dim,
-            module__hidden_layers_enc=self.hidden_layers_enc,
-            module__hidden_layers_dec=self.hidden_layers_dec,
-            module__activation=self.activation,
-            dataset__min_context_points=self.min_context_points,
-            dataset__max_context_points=self.max_context_points,
-            dataset__n_episode=self.n_episode,
-            max_epochs=self.max_epochs,
-            lr=self.lr,
-            batch_size=self.batch_size,
-            optimizer=self.optimizer,
-            device=self.device,
-            dataset=CNPDataset,  # special dataset to sample context and target sets
-            criterion=CNPLoss,
-            iterator_train__collate_fn=cnp_collate_fn,  # special collate fn to handle varying n per episode
-            iterator_valid__collate_fn=cnp_collate_fn,
-            callbacks=[
-                ("early_stopping", EarlyStopping(patience=10)),
-                (
-                    "lr_scheduler",
-                    LRScheduler(policy="ReduceLROnPlateau", patience=5, factor=0.5),
-                ),
-                ("grad_norm", GradientNormClipping(gradient_clip_value=1.0)),
-            ],
-            # train_split=None,
-            verbose=0,
-        )
-        self.model_.fit(X, y)
-        self.X_train_ = X
-        self.y_train_ = y
-        self.n_features_in_ = X.shape[1]
-        return self
-
-    @property
-    def model_name(self):
-        return "AttentiveConditionalNeuralProcess"
diff --git a/autoemulate/emulators/neural_networks/attn_cnp_module.py b/autoemulate/emulators/neural_networks/cnp_module_attn.py
similarity index 100%
rename from autoemulate/emulators/neural_networks/attn_cnp_module.py
rename to autoemulate/emulators/neural_networks/cnp_module_attn.py
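Spelling out every constructor argument in the subclass (instead of the earlier **kwargs) is what makes the estimator grid-searchable: scikit-learn's get_params()/clone() discover hyperparameters by introspecting the __init__ signature, and **kwargs hides them. A quick illustration of the behaviour this enables (values are arbitrary):

    from sklearn.base import clone

    from autoemulate.emulators import AttentiveConditionalNeuralProcess

    est = AttentiveConditionalNeuralProcess(hidden_dim=32)
    assert est.get_params()["hidden_dim"] == 32  # discovered from __init__
    est2 = clone(est).set_params(latent_dim=128)  # what RandomizedSearchCV does internally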
diff --git a/tests/models/test_attn_cnp.py b/tests/models/test_attn_cnp.py
index 07efaa75..43cdbee1 100644
--- a/tests/models/test_attn_cnp.py
+++ b/tests/models/test_attn_cnp.py
@@ -1,7 +1,12 @@
 import pytest
 import torch
 import torch.nn as nn
+from sklearn.datasets import make_regression
+from sklearn.model_selection import RandomizedSearchCV

+from autoemulate.emulators.conditional_neural_process_attn import (
+    AttentiveConditionalNeuralProcess,
+)
 from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule
 from autoemulate.emulators.neural_networks.cnp_module_attn import Decoder
 from autoemulate.emulators.neural_networks.cnp_module_attn import Encoder
@@ -218,3 +223,22 @@ def test_attn_cnp_module_forward_shape_2d(attn_cnp_module_2d):

     assert mean.shape == (b, n, dy)
     assert logvar.shape == (b, n, dy)
+
+
+# test whether param search works
+def test_attn_cnp_param_search():
+    X, y = make_regression(n_samples=30, n_features=5, n_targets=2, random_state=0)
+    param_grid = {
+        "hidden_dim": [16, 32],
+        "latent_dim": [16, 32],
+    }
+
+    mod = AttentiveConditionalNeuralProcess()
+
+    grid_search = RandomizedSearchCV(
+        estimator=mod, param_distributions=param_grid, cv=3, n_iter=3
+    )
+    grid_search.fit(X, y)
+
+    assert grid_search.best_score_ > 0.3
+    assert grid_search.best_params_["hidden_dim"] in [16, 32]
diff --git a/tests/test_estimators.py b/tests/test_estimators.py
index bd38c08c..ec430523 100644
--- a/tests/test_estimators.py
+++ b/tests/test_estimators.py
@@ -13,7 +13,9 @@
 from sklearn.utils.estimator_checks import parametrize_with_checks
 from sklearn.utils.estimator_checks import set_random_state

+from autoemulate.emulators import AttentiveConditionalNeuralProcess
 from autoemulate.emulators import ConditionalNeuralProcess
+from autoemulate.emulators import GaussianProcess
 from autoemulate.emulators import GaussianProcessMOGP
 from autoemulate.emulators import GaussianProcessMT
 from autoemulate.emulators import GaussianProcessSklearn
@@ -38,6 +40,8 @@
         LightGBM(),
         ConditionalNeuralProcess(random_state=42),
         GaussianProcessMT(random_state=42),
+        AttentiveConditionalNeuralProcess(random_state=42),
+        GaussianProcess(random_state=42),
     ]
 )
 def test_check_estimator(estimator, check):
From 90b8183cee9a28ffe175cfc4b2ad522c730aecfe Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Tue, 15 Oct 2024 11:38:43 +0100
Subject: [PATCH 4/5] fix model retrieval bug

---
 autoemulate/compare.py          | 3 ++-
 autoemulate/model_processing.py | 8 +++++---
 autoemulate/model_registry.py   | 4 ++--
 tests/test_model_registry.py    | 9 +++++++++
 4 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/autoemulate/compare.py b/autoemulate/compare.py
index dee40c3e..728d9291 100644
--- a/autoemulate/compare.py
+++ b/autoemulate/compare.py
@@ -108,9 +108,10 @@ def setup(
             self.X, test_size=self.test_set_size, random_state=42
         )
         self.model_names = self.model_registry.get_model_names(models, is_core=True)
+        print(self.model_registry.get_model_names(models))
         self.models = _process_models(
             model_registry=self.model_registry,
-            models=list(self.model_names.keys()),
+            model_names=list(self.model_names.keys()),
             y=self.y,
             scale=scale,
             scaler=scaler,
diff --git a/autoemulate/model_processing.py b/autoemulate/model_processing.py
index 65da5850..9e399359 100644
--- a/autoemulate/model_processing.py
+++ b/autoemulate/model_processing.py
@@ -65,14 +65,16 @@ def _wrap_models_in_pipeline(models, scale, scaler, reduce_dim, dim_reducer):
     return models_piped


-def _process_models(model_registry, models, y, scale, scaler, reduce_dim, dim_reducer):
+def _process_models(
+    model_registry, model_names, y, scale, scaler, reduce_dim, dim_reducer
+):
     """Get and process models.

     Parameters
     ----------
     model_registry : ModelRegistry
         An instance of the ModelRegistry class.
-    models : list
+    model_names : list
         List of model names.
     y : array-like, shape (n_samples, n_outputs)
         Simulation output.
@@ -86,7 +88,7 @@ def _process_models(
     models : list
         List of model instances.
     """
-    models = model_registry.get_models(models)
+    models = model_registry.get_models(model_names)
     models_multi = _turn_models_into_multioutput(models, y)
     models_scaled = _wrap_models_in_pipeline(
         models_multi, scale, scaler, reduce_dim, dim_reducer
diff --git a/autoemulate/model_registry.py b/autoemulate/model_registry.py
index fce0c7e9..3e5e0559 100644
--- a/autoemulate/model_registry.py
+++ b/autoemulate/model_registry.py
@@ -23,7 +23,7 @@ def get_model_names(self, models=None, is_core=False):
         models : str or list of str
             The name(s) of the model(s) to get long and short names for.
         is_core : bool
-            Whether to return only core model names.
+            Whether to return only core model names in case `models` is None.

         Returns
         -------
@@ -61,7 +61,7 @@ def get_model_names(self, models=None, is_core=False):
                 k: v for k, v in model_names.items() if k in models or v in models
             }

-        if is_core:
+        if models is None and is_core:
             model_names = {
                 k: v for k, v in model_names.items() if k in self.core_model_names
             }
diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py
index 5c6e5df9..9c975dbe 100644
--- a/tests/test_model_registry.py
+++ b/tests/test_model_registry.py
@@ -126,3 +126,12 @@ def test_get_models_mix(model_registry):
     assert len(models) == 2
     assert models[0].__class__ == RadialBasisFunctions
     assert models[1].__class__ == ConditionalNeuralProcess
+
+
+def test_get_noncore_model_w_is_core_true(model_registry):
+    # gps is non-core, but explicitly requested models are returned even with
+    # is_core=True, which only filters the default list (compare.py relies on this)
+    model_names = model_registry.get_model_names(models="gps", is_core=True)
+    assert isinstance(model_names, dict)
+    assert len(model_names) == 1
+    assert model_names["GaussianProcessSklearn"] == "gps"

From c4c1a09db5ce4b2b8845513f7ce846e01ce9440b Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Tue, 15 Oct 2024 11:43:34 +0100
Subject: [PATCH 5/5] remove print

---
 autoemulate/compare.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/autoemulate/compare.py b/autoemulate/compare.py
index 728d9291..7307ef03 100644
--- a/autoemulate/compare.py
+++ b/autoemulate/compare.py
@@ -108,7 +108,6 @@ def setup(
             self.X, test_size=self.test_set_size, random_state=42
         )
         self.model_names = self.model_registry.get_model_names(models, is_core=True)
-        print(self.model_registry.get_model_names(models))
         self.models = _process_models(
             model_registry=self.model_registry,
             model_names=list(self.model_names.keys()),
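The registry contract that patch 4 settles on, as exercised by the new test (a sketch assuming a populated ModelRegistry instance):

    # default listing: is_core=True restricts the result to core models
    model_registry.get_model_names(is_core=True)

    # explicit request wins: returns {"GaussianProcessSklearn": "gps"}
    # even though gps is registered with is_core=False
    model_registry.get_model_names("gps", is_core=True)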