Attention #259

Merged: 5 commits, Oct 15, 2024
2 changes: 1 addition & 1 deletion autoemulate/compare.py
@@ -110,7 +110,7 @@ def setup(
self.model_names = self.model_registry.get_model_names(models, is_core=True)
self.models = _process_models(
model_registry=self.model_registry,
- models=list(self.model_names.keys()),
+ model_names=list(self.model_names.keys()),
y=self.y,
scale=scale,
scaler=scaler,
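Since `_process_models` is invoked with keyword arguments, the caller's keyword has to match the renamed parameter; keeping the old `models=` keyword would raise a TypeError at setup time. A minimal sketch of the mismatch, using a stub with the new signature (illustrative only):

    def _process_models(model_registry, model_names, y, scale, scaler, reduce_dim, dim_reducer):
        ...

    # _process_models(model_registry=reg, models=names, ...)       # TypeError: unexpected keyword argument 'models'
    # _process_models(model_registry=reg, model_names=names, ...)  # matches the renamed parameter
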
6 changes: 6 additions & 0 deletions autoemulate/emulators/__init__.py
@@ -1,5 +1,6 @@
from ..model_registry import ModelRegistry
from .conditional_neural_process import ConditionalNeuralProcess
+ from .conditional_neural_process_attn import AttentiveConditionalNeuralProcess
from .gaussian_process import GaussianProcess
from .gaussian_process_mogp import GaussianProcessMOGP
from .gaussian_process_mt import GaussianProcessMT
@@ -38,6 +39,11 @@


# non-core models
+ model_registry.register_model(
+     AttentiveConditionalNeuralProcess().model_name,
+     AttentiveConditionalNeuralProcess,
+     is_core=False,
+ )
model_registry.register_model(
GaussianProcessMT().model_name, GaussianProcessMT, is_core=False
)
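Registering the class with `is_core=False` keeps the attentive CNP out of the default comparison while making it selectable by name. A rough usage sketch, assuming the `setup(X, y, models=...)` signature implied by the `get_model_names(models, is_core=True)` call in compare.py above (the short name comes from `AttentiveConditionalNeuralProcess().model_name` and is not shown in this diff):

    import numpy as np
    from autoemulate.compare import AutoEmulate

    X = np.random.rand(100, 2)
    y = np.random.rand(100, 1)

    em = AutoEmulate()
    # request the non-core emulator explicitly by its registered long name
    em.setup(X, y, models=["AttentiveConditionalNeuralProcess"])
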
7 changes: 3 additions & 4 deletions autoemulate/emulators/conditional_neural_process.py
@@ -16,6 +16,7 @@
from torch import nn

from autoemulate.emulators.neural_networks.cnp_module import CNPModule
+ from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule
from autoemulate.emulators.neural_networks.datasets import cnp_collate_fn
from autoemulate.emulators.neural_networks.datasets import CNPDataset
from autoemulate.emulators.neural_networks.losses import CNPLoss
@@ -140,9 +141,6 @@ def __init__(
self.activation = activation
self.optimizer = optimizer
self.normalize_y = normalize_y
- if attention:
-     warnings.warn("Attention is not implemented yet, setting to False.")
-     attention = False
self.attention = attention
self.device = device
self.random_state = random_state
@@ -181,8 +179,9 @@ def fit(self, X, y):
if self.random_state is not None:
set_random_seed(self.random_state)

+ module = CNPModule if not self.attention else AttnCNPModule
self.model_ = NeuralNetRegressor(
- CNPModule,
+ module,
module__input_dim=self.input_dim_,
module__output_dim=self.output_dim_,
module__hidden_dim=self.hidden_dim,
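With the fallback warning gone, `attention=True` now swaps the skorch module class instead of being silently reset to False. A condensed sketch of the dispatch added in `fit()`, showing only the architecture-related skorch arguments (the real call also passes training and episode settings):

    from skorch import NeuralNetRegressor

    from autoemulate.emulators.neural_networks.cnp_module import CNPModule
    from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule

    def build_net(attention, input_dim, output_dim, hidden_dim=64, latent_dim=64):
        module = CNPModule if not attention else AttnCNPModule
        return NeuralNetRegressor(
            module,                              # skorch instantiates the class itself
            module__input_dim=input_dim,         # module__* kwargs go to module.__init__
            module__output_dim=output_dim,
            module__hidden_dim=hidden_dim,
            module__latent_dim=latent_dim,
            module__hidden_layers_enc=3,
            module__hidden_layers_dec=3,
        )
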
51 changes: 51 additions & 0 deletions autoemulate/emulators/conditional_neural_process_attn.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin

from autoemulate.emulators.conditional_neural_process import ConditionalNeuralProcess
from autoemulate.utils import set_random_seed


class AttentiveConditionalNeuralProcess(ConditionalNeuralProcess):
def __init__(
self,
# architecture
hidden_dim=64,
latent_dim=64,
hidden_layers_enc=3,
hidden_layers_dec=3,
# data per episode
min_context_points=3,
max_context_points=10,
n_episode=32,
# training
max_epochs=100,
lr=5e-3,
batch_size=16,
activation=nn.ReLU,
optimizer=torch.optim.AdamW,
normalize_y=True,
# misc
device="cpu",
random_state=None,
attention=True,
):
super().__init__(
hidden_dim=hidden_dim,
latent_dim=latent_dim,
hidden_layers_enc=hidden_layers_enc,
hidden_layers_dec=hidden_layers_dec,
min_context_points=min_context_points,
max_context_points=max_context_points,
n_episode=n_episode,
max_epochs=max_epochs,
lr=lr,
batch_size=batch_size,
activation=activation,
optimizer=optimizer,
normalize_y=normalize_y,
device=device,
random_state=random_state,
attention=attention,
)
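
Because the subclass only changes constructor defaults (most importantly `attention=True`), it behaves like any other sklearn-style emulator in the package. A rough sketch on toy data, assuming the parent's `fit`/`predict` interface (hyperparameter values are illustrative):

    import numpy as np
    from autoemulate.emulators import AttentiveConditionalNeuralProcess

    X = np.random.rand(200, 3).astype(np.float32)
    y = np.sin(X.sum(axis=1, keepdims=True)).astype(np.float32)

    emulator = AttentiveConditionalNeuralProcess(max_epochs=10, random_state=0)
    emulator.fit(X, y)
    y_pred = emulator.predict(X[:5])
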
151 changes: 151 additions & 0 deletions autoemulate/emulators/neural_networks/cnp_module_attn.py
@@ -0,0 +1,151 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import loguniform
from skopt.space import Categorical
from skopt.space import Real


class Encoder(nn.Module):
"""
Deterministic encoder for the attentive conditional neural process model.
"""

def __init__(
self,
input_dim,
output_dim,
hidden_dim,
latent_dim,
hidden_layers_enc,
activation,
context_mask=None,
):
super().__init__()
layers = [nn.Linear(input_dim + output_dim, hidden_dim), activation()]
for _ in range(hidden_layers_enc):
layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
layers.append(nn.Linear(hidden_dim, latent_dim))
self.net = nn.Sequential(*layers)

self.x_encoder = nn.Linear(input_dim, latent_dim)

self.crossattn = nn.MultiheadAttention(
embed_dim=latent_dim, num_heads=4, batch_first=True
)

def forward(self, x_context, y_context, x_target, context_mask=None):
"""
Encode the context set and let the target inputs cross-attend to it.

Parameters
----------
x_context: (batch_size, n_context_points, input_dim)
y_context: (batch_size, n_context_points, output_dim)
x_target: (batch_size, n_target_points, input_dim)
context_mask: (batch_size, n_context_points)

Returns
-------
r: (batch_size, n_target_points, latent_dim)
"""
# encode (x, y) context pairs with an MLP
x = torch.cat([x_context, y_context], dim=-1)
r = self.net(x)
# cross-attention: query = encoded targets, key = encoded context inputs, value = r
x_target_enc = self.x_encoder(x_target)
x_context_enc = self.x_encoder(x_context)
if context_mask is not None:
r, _ = self.crossattn(
x_target_enc,
x_context_enc,
r,
need_weights=False,
key_padding_mask=context_mask,
)
else:
r, _ = self.crossattn(x_target_enc, x_context_enc, r, need_weights=False)
return r


class Decoder(nn.Module):
def __init__(
self,
input_dim,
latent_dim,
hidden_dim,
output_dim,
hidden_layers_dec,
activation,
):
super().__init__()
layers = [nn.Linear(latent_dim + input_dim, hidden_dim), activation()]
for _ in range(hidden_layers_dec):
layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
self.net = nn.Sequential(*layers)
self.mean_head = nn.Linear(hidden_dim, output_dim)
self.logvar_head = nn.Linear(hidden_dim, output_dim)

def forward(self, r, x_target):
"""
Decode using representation r and target points x_target

Parameters
----------
r: (batch_size, n_points, latent_dim)
x_target: (batch_size, n_points, input_dim)

Returns
-------
mean: (batch_size, n_points, output_dim)
logvar: (batch_size, n_points, output_dim)
"""
x = torch.cat([r, x_target], dim=-1)
hidden = self.net(x)
mean = self.mean_head(hidden)
logvar = self.logvar_head(hidden)

return mean, logvar


class AttnCNPModule(nn.Module):
def __init__(
self,
input_dim,
output_dim,
hidden_dim,
latent_dim,
hidden_layers_enc,
hidden_layers_dec,
activation=nn.ReLU,
):
super().__init__()
self.encoder = Encoder(
input_dim, output_dim, hidden_dim, latent_dim, hidden_layers_enc, activation
)
self.decoder = Decoder(
input_dim, latent_dim, hidden_dim, output_dim, hidden_layers_dec, activation
)

def forward(self, X_context, y_context, X_target=None, context_mask=None):
"""

Parameters
----------
X_context: (batch_size, n_context_points, input_dim)
y_context: (batch_size, n_context_points, output_dim)
X_target: (batch_size, n_target_points, input_dim)
context_mask: (batch_size, n_context_points), currently unused:
it is inverted below but not yet passed on to the encoder's cross-attention.

Returns
-------
mean: (batch_size, n_points, output_dim)
logvar: (batch_size, n_points, output_dim)
"""
# invert context_mask (key_padding_mask expects True for padded positions)
if context_mask is not None:
context_mask = ~context_mask
r = self.encoder(X_context, y_context, X_target)
mean, logvar = self.decoder(r, X_target)
return mean, logvar
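
Shape-wise, the module maps a batch of context sets plus target inputs to a per-target Gaussian (mean and log-variance). A quick sanity-check sketch with random tensors (dimensions are arbitrary; `latent_dim` must stay divisible by the 4 attention heads):

    import torch

    module = AttnCNPModule(
        input_dim=2, output_dim=1, hidden_dim=32, latent_dim=32,
        hidden_layers_enc=2, hidden_layers_dec=2,
    )
    X_context = torch.randn(8, 10, 2)  # (batch, n_context_points, input_dim)
    y_context = torch.randn(8, 10, 1)  # (batch, n_context_points, output_dim)
    X_target = torch.randn(8, 15, 2)   # (batch, n_target_points, input_dim)

    mean, logvar = module(X_context, y_context, X_target)
    assert mean.shape == logvar.shape == (8, 15, 1)
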
8 changes: 5 additions & 3 deletions autoemulate/model_processing.py
@@ -65,14 +65,16 @@ def _wrap_models_in_pipeline(models, scale, scaler, reduce_dim, dim_reducer):
return models_piped


- def _process_models(model_registry, models, y, scale, scaler, reduce_dim, dim_reducer):
+ def _process_models(
+     model_registry, model_names, y, scale, scaler, reduce_dim, dim_reducer
+ ):
"""Get and process models.

Parameters
----------
model_registry : ModelRegistry
An instance of the ModelRegistry class.
- models : list
+ model_names : list
List of model names.
y : array-like, shape (n_samples, n_outputs)
Simulation output.
@@ -86,7 +88,7 @@ def _process_models(model_registry, models, y, scale, scaler, reduce_dim, dim_reducer):
models : list
List of model instances.
"""
- models = model_registry.get_models(models)
+ models = model_registry.get_models(model_names)
models_multi = _turn_models_into_multioutput(models, y)
models_scaled = _wrap_models_in_pipeline(
models_multi, scale, scaler, reduce_dim, dim_reducer
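Seen from the outside, the helper takes the registry and a list of names, instantiates the models, wraps them for multi-output targets where needed, and returns them inside pipelines. A hedged sketch of a direct call, assuming access to the module-level `model_registry` instance created in emulators/__init__.py (names and data are illustrative):

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    from autoemulate.emulators import model_registry
    from autoemulate.model_processing import _process_models

    y = np.random.rand(50, 2)  # simulation outputs, shape (n_samples, n_outputs)
    models = _process_models(
        model_registry=model_registry,
        model_names=["AttentiveConditionalNeuralProcess"],
        y=y,
        scale=True,
        scaler=StandardScaler(),
        reduce_dim=False,
        dim_reducer=PCA(),
    )
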
4 changes: 2 additions & 2 deletions autoemulate/model_registry.py
@@ -23,7 +23,7 @@ def get_model_names(self, models=None, is_core=False):
models : str or list of str
The name(s) of the model(s) to get long and short names for.
is_core : bool
- Whether to return only core model names.
+ Whether to return only core model names in case `models` is None.

Returns
-------
@@ -61,7 +61,7 @@ def get_model_names(self, models=None, is_core=False):
k: v for k, v in model_names.items() if k in models or v in models
}

- if is_core:
+ if models is None and is_core:
model_names = {
k: v for k, v in model_names.items() if k in self.core_model_names
}
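The extra `models is None` guard means the core filter now only applies to the default listing; an explicitly requested non-core model (such as the attentive CNP registered above) is no longer silently dropped. A hedged before/after sketch, assuming access to the module-level `model_registry` instance (the returned dicts map between full and short model names; exact short names are not shown in this diff):

    # default listing: restricted to core models when is_core=True
    core = model_registry.get_model_names(is_core=True)

    # explicit request: previously filtered out because it is not a core model,
    # now returned as asked for
    acnp = model_registry.get_model_names(
        ["AttentiveConditionalNeuralProcess"], is_core=True
    )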