Skip to content

Commit

Permalink
Add 1D Poisson PI-DeepONet example as in DeepXDE.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Apr 19, 2024
1 parent 7cb5bf5 commit 104146a
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ dev = [
"neoteroi-mkdocs",
"pygments",
"gmsh",
"deepxde",
]

[tool.setuptools.dynamic]
Expand Down
21 changes: 21 additions & 0 deletions src/continuiti/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,24 @@
"""
__version__ = "0.0.0"

__all__ = [
"benchmarks",
"data",
"discrete",
"operators",
"pde",
"trainer",
"transforms",
"Trainer",
]

from . import benchmarks
from . import data
from . import discrete
from . import operators
from . import pde
from . import trainer
from . import transforms

from .trainer import Trainer
99 changes: 99 additions & 0 deletions tests/pde/test_deepxde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(0)


def test_deepxde():
"""Physics-informed DeepONet for Poisson equation in 1D.
Example from DeepXDE.
https://deepxde.readthedocs.io/en/latest/demos/operator/poisson.1d.pideeponet.html
"""

# Poisson equation: -u_xx = f
def equation(x, y, f):
dy_xx = dde.grad.hessian(y, x)
return -dy_xx - f

# Domain is interval [0, 1]
geom = dde.geometry.Interval(0, 1)

# Zero Dirichlet BC
def u_boundary(_):
return 0

def boundary(_, on_boundary):
return on_boundary

bc = dde.icbc.DirichletBC(geom, u_boundary, boundary)

# Define PDE
pde = dde.data.PDE(geom, equation, bc, num_domain=100, num_boundary=2)

# Function space for f(x) are polynomials
degree = 3
space = dde.data.PowerSeries(N=degree + 1)

# Choose evaluation points
num_eval_points = 10
evaluation_points = geom.uniform_points(num_eval_points, boundary=True)

# Define PDE operator
pde_op = dde.data.PDEOperatorCartesianProd(
pde,
space,
evaluation_points,
num_function=100,
)

# Setup DeepONet
dim_x = 1
p = 32
net = dde.nn.DeepONetCartesianProd(
[num_eval_points, 32, p],
[dim_x, 32, p],
activation="tanh",
kernel_initializer="Glorot normal",
)
print("Params:", sum(p.numel() for p in net.parameters()))

# Define and train model
model = dde.Model(pde_op, net)
model.compile("adam", lr=0.001)
model.train(epochs=1000)

# Plot realizations of f(x)
n = 3
features = space.random(n)
fx = space.eval_batch(features, evaluation_points)

x = geom.uniform_points(100, boundary=True)
y = model.predict((fx, x))

# Setup figure
fig = plt.figure(figsize=(7, 8))
plt.subplot(2, 1, 1)
plt.title("Poisson equation: Source term f(x) and solution u(x)")
plt.ylabel("f(x)")
z = np.zeros_like(x)
plt.plot(x, z, "k-", alpha=0.1)

# Plot source term f(x)
for i in range(n):
plt.plot(evaluation_points, fx[i], ".")

# Plot solution u(x)
plt.subplot(2, 1, 2)
plt.ylabel("u(x)")
plt.plot(x, z, "k-", alpha=0.1)
for i in range(n):
plt.plot(x, y[i], "-")
plt.xlabel("x")

plt.show()


if __name__ == "__main__":
test_deepxde()
122 changes: 122 additions & 0 deletions tests/pde/test_pideeponet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import torch
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
import continuiti as cti

torch.manual_seed(0)


def test_pideeponet():
"""Physics-informed DeepONet for Poisson equation in 1D.
Example from DeepXDE in *continuiti*.
https://deepxde.readthedocs.io/en/latest/demos/operator/poisson.1d.pideeponet.html
"""

# Poisson equation: -v_xx = f
mse = torch.nn.MSELoss()

def equation(_, f, y, v):
# PDE
dy_xx = dde.grad.hessian(v, y)
inner_loss = mse(-dy_xx, f)

# BC
y_bnd, v_bnd = y[:, :, bnd_indices], v[:, :, bnd_indices]
boundary_loss = mse(v_bnd, v_boundary(y_bnd))

return inner_loss + boundary_loss

# Domain is interval [0, 1]
geom = dde.geometry.Interval(0, 1)

# Zero Dirichlet BC
def v_boundary(y):
return torch.zeros_like(y)

# Sample domain and boundary points
num_domain = 100
num_boundary = 2
x_domain = geom.uniform_points(num_domain)
x_bnd = geom.uniform_boundary_points(num_boundary)

x = np.concatenate([x_domain, x_bnd])
num_points = len(x)
bnd_indices = range(num_domain, num_points)

# Function space for f(x) are polynomials
degree = 3
space = dde.data.PowerSeries(N=degree + 1)

num_functions = 100
coeffs = space.random(num_functions)
fx = space.eval_batch(coeffs, x)

# Specify dataset
xt = torch.tensor(x.T).requires_grad_(True)
x_all = xt.expand(num_functions, -1, -1) # (num_functions, x_dim, num_domain)
u_all = torch.tensor(fx).unsqueeze(1) # (num_functions, u_dim, num_domain)
y_all = x_all # (num_functions, y_dim, num_domain)
v_all = torch.zeros_like(y_all) # (num_functions, v_dim, num_domain)

dataset = cti.data.OperatorDataset(
x=x_all,
u=u_all,
y=y_all, # same as x_all
v=v_all, # only for shapes
)

# Define operator
operator = cti.operators.DeepONet(
dataset.shapes,
trunk_depth=1,
branch_depth=1,
basis_functions=32,
)
# or any other operator, e.g.:
# operator = cti.operators.DeepNeuralOperator(dataset.shapes)

# Define and train model
trainer = cti.Trainer(
operator,
loss_fn=cti.pde.PhysicsInformedLoss(equation),
)
trainer.fit(dataset)

# Plot realizations of f(x)
n = 3
features = space.random(n)
fx = space.eval_batch(features, x)

y = geom.uniform_points(200, boundary=True)

x_plot = torch.tensor(x_domain.T).expand(n, -1, -1)
u_plot = torch.tensor(fx).unsqueeze(1)
y_plot = torch.tensor(y.T).expand(n, -1, -1)
v = operator(x_plot, u_plot, y_plot)
v = v.detach().numpy()

fig = plt.figure(figsize=(7, 8))
plt.subplot(2, 1, 1)
plt.title("Poisson equation: Source term f(x) and solution v(x)")
plt.ylabel("f(x)")
z = np.zeros_like(y)
plt.plot(y, z, "k-", alpha=0.1)

# Plot source term f(x)
for i in range(n):
plt.plot(x, fx[i], ".")

# Plot solution v(x)
plt.subplot(2, 1, 2)
plt.ylabel("v(x)")
plt.plot(y, z, "k-", alpha=0.1)
for i in range(n):
plt.plot(y, v[i].T, "-")
plt.xlabel("x")

plt.show()


if __name__ == "__main__":
test_pideeponet()

0 comments on commit 104146a

Please sign in to comment.