Commit

Merge pull request #59 from aai-institute/feature/belnet
Feature: BelNet.
samuelburbulla authored Feb 23, 2024
2 parents 593ce5e + 6bf6eb2 commit 3854baf
Showing 6 changed files with 250 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -4,7 +4,7 @@

- Move all content of `__init__.py` files to sub-modules.
- Add `Trainer` class to replace `operator.fit` method.
- Implement `BelNet`.

## 0.0.0 (2024-02-22)

1 change: 1 addition & 0 deletions docs/operators/architectures.md
@@ -10,6 +10,7 @@ alias:
Continuity implements the following neural operator architectures:

- [DeepONet](../../api/continuity/operators/deeponet/)
- [BelNet](../../api/continuity/operators/belnet/)
- [(Fourier) Neural Operator](../../api/continuity/operators/neuraloperator/)

and more to come...
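
For orientation, a minimal usage sketch, assuming the `Sine` dataset from
`continuity.data.sine` and the `BelNet` constructor defaults added in this
pull request:

    from continuity.data.sine import Sine
    from continuity.operators import BelNet

    # Build a BelNet operator from the dataset's shape metadata.
    # K is the number of basis functions; N_1/D_1 and N_2/D_2 set width and
    # depth of the projection and construction networks (defaults shown).
    dataset = Sine(16, size=1)  # 16 sensors, dataset of size 1
    operator = BelNet(dataset.shapes, K=4, N_1=32, D_1=1, N_2=32, D_2=1)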
2 changes: 2 additions & 0 deletions src/continuity/operators/__init__.py
@@ -15,10 +15,12 @@
from .operator import Operator
from .neuraloperator import NeuralOperator
from .deeponet import DeepONet
from .belnet import BelNet

__all__ = [
"Operator",
"DeepONet",
"NeuralOperator",
"DeepResidualNetwork",
"BelNet",
]
127 changes: 127 additions & 0 deletions src/continuity/operators/belnet.py
@@ -0,0 +1,127 @@
"""
`continuity.operators.belnet`
The BelNet architecture.
"""

import torch
from typing import Optional
from continuity.operators import Operator
from continuity.operators.common import DeepResidualNetwork
from continuity.data import DatasetShapes


class BelNet(Operator):
    r"""
    The BelNet architecture is an extension of the DeepONet architecture that
    adds a learnable projection basis network to interpolate the sensor
    inputs. It therefore supports changing sensor positions, i.e., it is
    *discretization invariant*.

    *Reference:* Z. Zhang et al. BelNet: basis enhanced learning, a mesh-free
    neural operator. Proceedings of the Royal Society A (2023).

    **Note:** Figure 6 of the paper can be used as a reference, but we swapped
    the notation of `x` and `y` to comply with the convention in Continuity,
    where `x` denotes the sensor positions and `y` the evaluation points. We
    also replace the single-layer projection and construction networks with
    more expressive deep residual networks.

    Args:
        shapes: Shape variable of the dataset
        K: Number of basis functions
        N_1: Width of the projection basis network
        D_1: Depth of the projection basis network
        N_2: Width of the construction network
        D_2: Depth of the construction network
        a_x: Activation function of the projection networks
        a_u: Activation function applied after the projection
        a_y: Activation function of the construction network
    """

    def __init__(
        self,
        shapes: DatasetShapes,
        K: int = 4,
        N_1: int = 32,
        D_1: int = 1,
        N_2: int = 32,
        D_2: int = 1,
        a_x: Optional[torch.nn.Module] = None,
        a_u: Optional[torch.nn.Module] = None,
        a_y: Optional[torch.nn.Module] = None,
    ):
        super().__init__()

        self.shapes = shapes
        self.K = K
        self.a_x = a_x or torch.nn.Tanh()
        self.a_u = a_u or torch.nn.LeakyReLU()
        self.a_y = a_y or torch.nn.Tanh()

        self.Nx = self.shapes.x.num * self.shapes.x.dim
        self.Nu = self.shapes.u.num * self.shapes.u.dim
        self.Kv = K * self.shapes.v.dim

        # K projection nets, each mapping the flattened sensor positions to
        # projection coefficients for the flattened input function values
        self.p = torch.nn.ModuleList(
            [
                DeepResidualNetwork(
                    input_size=self.Nx,
                    output_size=self.Nu,
                    width=N_1,
                    depth=D_1,
                    act=self.a_x,
                )
                for _ in range(K)
            ]
        )

        # construction net, mapping an evaluation point to K * v_dim basis values
        self.q = DeepResidualNetwork(
            input_size=shapes.y.dim,
            output_size=self.Kv,
            width=N_2,
            depth=D_2,
            act=self.a_y,
        )

    def forward(
        self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass through the operator.

        Args:
            x: Sensor positions of shape (batch_size, #sensors, x_dim).
            u: Input function values of shape (batch_size, #sensors, u_dim).
            y: Evaluation coordinates of shape (batch_size, #evaluations, y_dim).

        Returns:
            Operator output of shape (batch_size, #evaluations, v_dim).
        """
        assert x.size(0) == u.size(0) == y.size(0)
        num_evaluations = y.size(1)

        # flatten inputs
        x = x.reshape(-1, self.Nx)
        u = u.reshape(-1, self.Nu)
        y = y.reshape(-1, self.shapes.y.dim)

        # build projection matrix
        P = torch.stack([p(x) for p in self.p], dim=1)
        assert P.shape[1:] == torch.Size([self.K, self.Nu])

        # perform the projection
        aPu = self.a_u(torch.einsum("bkn,bn->bk", P, u))
        assert aPu.shape[1:] == torch.Size([self.K])

        # construction net
        Q = self.q(y)
        assert Q.shape[1:] == torch.Size([self.Kv])

        # dot product of projection coefficients and basis functions
        Q = Q.reshape(-1, num_evaluations, self.K, self.shapes.v.dim)
        output = torch.einsum("bk,bckv->bcv", aPu, Q)
        assert output.shape[1:] == torch.Size([num_evaluations, self.shapes.v.dim])

        return output
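
In symbols, the forward pass above computes, for each evaluation point
$y_c$ (a sketch of the einsum chain, with $p_k$ the $K$ projection nets,
$q$ the construction net, and $a_u$ the projection activation):

$$ v(y_c)_j = \sum_{k=1}^{K} a_u\Big( \sum_{n=1}^{N_u} p_k(x)_n \, u_n \Big) \, q(y_c)_{kj}, $$

where $x \in \mathbb{R}^{N_x}$ and $u \in \mathbb{R}^{N_u}$ are the
flattened sensor positions and input function values.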
41 changes: 38 additions & 3 deletions src/continuity/operators/common.py
@@ -5,19 +5,51 @@
"""

import torch
from typing import Optional


class FullyConnected(torch.nn.Module):
    """Fully connected network.

    Args:
        input_size: Input dimension.
        output_size: Output dimension.
        width: Width of the hidden layer.
        act: Activation function.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        width: int,
        act: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        self.inner_layer = torch.nn.Linear(input_size, width)
        self.outer_layer = torch.nn.Linear(width, output_size)
        self.act = act or torch.nn.Tanh()

    def forward(self, x: torch.Tensor):
        """Forward pass."""
        x = self.inner_layer(x)
        x = self.act(x)
        x = self.outer_layer(x)
        return x


class ResidualLayer(torch.nn.Module):
    """Residual layer.

    Args:
        width: Width of the layer.
        act: Activation function.
    """

    def __init__(self, width: int, act: Optional[torch.nn.Module] = None):
        super().__init__()
        self.layer = torch.nn.Linear(width, width)
        self.act = act or torch.nn.Tanh()

    def forward(self, x: torch.Tensor):
        """Forward pass."""

@@ -32,6 +64,7 @@ class DeepResidualNetwork(torch.nn.Module):
        output_size: Size of output tensor
        width: Width of hidden layers
        depth: Number of hidden layers
        act: Activation function
    """

    def __init__(
@@ -40,12 +73,14 @@ def __init__(
        output_size: int,
        width: int,
        depth: int,
        act: Optional[torch.nn.Module] = None,
    ):
        super().__init__()

        self.act = act or torch.nn.Tanh()
        self.first_layer = torch.nn.Linear(input_size, width)
        self.hidden_layers = torch.nn.ModuleList(
            [ResidualLayer(width, act=self.act) for _ in range(depth)]
        )
        self.last_layer = torch.nn.Linear(width, output_size)
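For reference, a standalone sketch of the new `act` parameter in use,
assuming only the `DeepResidualNetwork` constructor signature shown above:

    import torch
    from continuity.operators.common import DeepResidualNetwork

    # A residual MLP with 3 hidden residual layers of width 32; LeakyReLU
    # now propagates into every ResidualLayer instead of the default Tanh.
    net = DeepResidualNetwork(
        input_size=2,
        output_size=1,
        width=32,
        depth=3,
        act=torch.nn.LeakyReLU(),
    )
    y = net(torch.rand(8, 2))
    assert y.shape == (8, 1)
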
81 changes: 81 additions & 0 deletions tests/operators/test_belnet.py
@@ -0,0 +1,81 @@
import torch
import matplotlib.pyplot as plt
import pytest

from continuity.plotting import plot, plot_evaluation
from torch.utils.data import DataLoader
from continuity.operators import BelNet
from continuity.data import OperatorDataset
from continuity.data.sine import Sine
from continuity.trainer import Trainer
from continuity.operators.losses import MSELoss


def test_belnet_shape():
    x_dim = 2
    u_dim = 3
    y_dim = 5
    v_dim = 7
    n_sensors = 11
    n_evals = 13
    batch_size = 17
    set_size = 19

    dset = OperatorDataset(
        x=torch.rand((set_size, n_sensors, x_dim)),
        u=torch.rand((set_size, n_sensors, u_dim)),
        y=torch.rand((set_size, n_evals, y_dim)),
        v=torch.rand((set_size, n_evals, v_dim)),
    )

    model = BelNet(dset.shapes)

    x, u, y, v = dset[:batch_size]

    v_pred = model(x, u, y)

    assert v_pred.shape == v.shape

    y_other = torch.rand((batch_size, n_evals, y_dim))
    v_other = torch.rand((batch_size, n_evals, v_dim))

    v_other_pred = model(x, u, y_other)

    assert v_other_pred.shape == v_other.shape


@pytest.mark.slow
def test_belnet():
    # Parameters
    num_sensors = 16

    # Data set
    dataset = Sine(num_sensors, size=1)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Operator
    operator = BelNet(
        dataset.shapes,
    )

    # Train self-supervised
    optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
    trainer = Trainer(operator, optimizer)
    trainer.fit(data_loader, epochs=1000)

    # Plotting
    fig, ax = plt.subplots(1, 1)
    x, u, _, _ = dataset[0]
    plot(x, u, ax=ax)
    plot_evaluation(operator, x, u, ax=ax)
    fig.savefig("test_belnet.png")

    # Check solution
    x = x.unsqueeze(0)
    u = u.unsqueeze(0)
    assert MSELoss()(operator, x, u, x, u) < 1e-3


if __name__ == "__main__":
    test_belnet_shape()
    test_belnet()
