-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from aai-institute/feature/belnet
Feature: BelNet.
- Loading branch information
Showing
6 changed files
with
250 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
""" | ||
`continuity.operators.belnet` | ||
The BelNet architecture. | ||
""" | ||
|
||
import torch | ||
from typing import Optional | ||
from continuity.operators import Operator | ||
from continuity.operators.common import DeepResidualNetwork | ||
from continuity.data import DatasetShapes | ||
|
||
|
||
class BelNet(Operator):
    r"""
    The BelNet architecture is an extension of the DeepONet architecture that
    adds a learnable projection basis network to interpolate the sensor inputs.
    Therefore, it supports changing sensor positions, or in other terms, is
    *discretization invariant*.

    *Reference:* Z. Zhang et al. BelNet: basis enhanced learning, a mesh-free
    neural operator. Proceedings of the Royal Society A (2023).

    **Note:** In the paper, you can use Figure 6 for reference, but we swapped
    the notation of `x` and `y` to comply with the convention in Continuity,
    where `x` is the collocation points and `y` is the evaluation points. We
    also replace the single layer projection and construction networks by more
    expressive deep residual networks.

    Args:
        shapes: Shape variable of the dataset
        K: Number of basis functions
        N_1: Width of the projection basis network
        D_1: Depth of the projection basis network
        N_2: Width of the construction network
        D_2: Depth of the construction network
        a_x: Activation function of projection networks
            (`torch.nn.Tanh()` if `None`)
        a_u: Activation function applied after the projection
            (`torch.nn.LeakyReLU()` if `None`)
        a_y: Activation function of the construction network
            (`torch.nn.Tanh()` if `None`)
    """

    def __init__(
        self,
        shapes: DatasetShapes,
        K: int = 4,
        N_1: int = 32,
        D_1: int = 1,
        N_2: int = 32,
        D_2: int = 1,
        a_x: Optional[torch.nn.Module] = None,
        a_u: Optional[torch.nn.Module] = None,
        a_y: Optional[torch.nn.Module] = None,
    ):
        super().__init__()

        self.shapes = shapes
        self.K = K

        # Compare against `None` explicitly instead of relying on truthiness:
        # modules that define `__len__` (e.g. an empty `torch.nn.Sequential`)
        # are falsy and would otherwise be silently replaced by the default.
        self.a_x = torch.nn.Tanh() if a_x is None else a_x
        self.a_u = torch.nn.LeakyReLU() if a_u is None else a_u
        self.a_y = torch.nn.Tanh() if a_y is None else a_y

        # Sizes of the flattened sensor positions / sensor values, and of the
        # construction net output (K basis functions per output dimension).
        self.Nx = self.shapes.x.num * self.shapes.x.dim
        self.Nu = self.shapes.u.num * self.shapes.u.dim
        self.Kv = K * self.shapes.v.dim

        # K projection nets, each mapping the flattened sensor positions to
        # one row of the projection matrix. All of them receive the same
        # activation module instance `self.a_x`.
        self.p = torch.nn.ModuleList(
            [
                DeepResidualNetwork(
                    input_size=self.Nx,
                    output_size=self.Nu,
                    width=N_1,
                    depth=D_1,
                    act=self.a_x,
                )
                for _ in range(K)
            ]
        )

        # construction net, evaluated at every single evaluation coordinate
        self.q = DeepResidualNetwork(
            input_size=shapes.y.dim,
            output_size=self.Kv,
            width=N_2,
            depth=D_2,
            act=self.a_y,
        )

    def forward(
        self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass through the operator.

        Args:
            x: Sensor positions of shape (batch_size, #sensors, x_dim).
            u: Input function values of shape (batch_size, #sensors, u_dim)
            y: Evaluation coordinates of shape (batch_size, #evaluations, y_dim)

        Returns:
            Operator output (batch_size, #evaluations, v_dim)
        """
        assert (
            x.size(0) == u.size(0) == y.size(0)
        ), "x, u and y must have the same batch size"
        num_evaluations = y.size(1)

        # flatten inputs: one row per sample for x and u, one row per
        # evaluation point for y
        x = x.reshape(-1, self.Nx)
        u = u.reshape(-1, self.Nu)
        y = y.reshape(-1, self.shapes.y.dim)

        # build projection matrix of shape (batch_size, K, Nu)
        P = torch.stack([p(x) for p in self.p], dim=1)
        assert P.shape[1:] == torch.Size([self.K, self.Nu])

        # perform the projection: contract P with u over the flattened
        # sensor axis, then apply the activation -> (batch_size, K)
        aPu = self.a_u(torch.einsum("bkn,bn->bk", P, u))
        assert aPu.shape[1:] == torch.Size([self.K])

        # construction net -> (batch_size * #evaluations, Kv)
        Q = self.q(y)
        assert Q.shape[1:] == torch.Size([self.Kv])

        # dot product over the K basis functions for every evaluation point
        Q = Q.reshape(-1, num_evaluations, self.K, self.shapes.v.dim)
        output = torch.einsum("bk,bckv->bcv", aPu, Q)
        assert output.shape[1:] == torch.Size([num_evaluations, self.shapes.v.dim])

        return output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import torch | ||
import matplotlib.pyplot as plt | ||
import pytest | ||
|
||
from continuity.plotting import plot, plot_evaluation | ||
from torch.utils.data import DataLoader | ||
from continuity.operators import BelNet | ||
from continuity.data import OperatorDataset | ||
from continuity.data.sine import OperatorDataset, Sine | ||
from continuity.trainer import Trainer | ||
from continuity.operators.losses import MSELoss | ||
|
||
|
||
def test_belnet_shape():
    """Check that BelNet outputs have the same shape as the target values."""
    # Pairwise distinct sizes, so that accidentally swapped axes show up as a
    # shape mismatch.
    x_dim, u_dim, y_dim, v_dim = 2, 3, 5, 7
    n_sensors, n_evals = 11, 13
    batch_size, set_size = 17, 19

    sensor_positions = torch.rand((set_size, n_sensors, x_dim))
    sensor_values = torch.rand((set_size, n_sensors, u_dim))
    eval_positions = torch.rand((set_size, n_evals, y_dim))
    eval_values = torch.rand((set_size, n_evals, v_dim))
    dset = OperatorDataset(
        x=sensor_positions,
        u=sensor_values,
        y=eval_positions,
        v=eval_values,
    )

    model = BelNet(dset.shapes)

    # Evaluate on the first `batch_size` samples of the data set.
    x, u, y, v = dset[:batch_size]
    assert model(x, u, y).shape == v.shape

    # Evaluate the same inputs at other, freshly drawn evaluation coordinates.
    y_other = torch.rand((batch_size, n_evals, y_dim))
    v_other = torch.rand((batch_size, n_evals, v_dim))
    assert model(x, u, y_other).shape == v_other.shape
|
||
|
||
@pytest.mark.slow
def test_belnet():
    """Train BelNet on a single-sample `Sine` data set and check the fit.

    A plot of the sample and the operator evaluation is written to
    `test_belnet.png` in the current working directory.
    """
    # Parameters
    num_sensors = 16

    # Data set
    dataset = Sine(num_sensors, size=1)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Operator
    operator = BelNet(
        dataset.shapes,
    )

    # Train self-supervised
    optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
    trainer = Trainer(operator, optimizer)
    trainer.fit(data_loader, epochs=1000)

    # Plotting
    fig, ax = plt.subplots(1, 1)
    x, u, _, _ = dataset[0]
    plot(x, u, ax=ax)
    plot_evaluation(operator, x, u, ax=ax)
    fig.savefig("test_belnet.png")
    # Release the figure so it does not stay registered with pyplot for the
    # rest of the test session.
    plt.close(fig)

    # Check solution: add the batch axis and pass the sensor data again as the
    # last two arguments.
    # NOTE(review): presumably MSELoss takes (operator, x, u, y, v), i.e. the
    # operator is evaluated at the sensor positions against u -- confirm.
    x = x.unsqueeze(0)
    u = u.unsqueeze(0)
    assert MSELoss()(operator, x, u, x, u) < 1e-3
|
||
|
||
if __name__ == "__main__":
    # Allow running this module directly as a script: execute both tests in
    # their order of definition.
    for run_test in (test_belnet_shape, test_belnet):
        run_test()