Skip to content

Commit

Permalink
Accelerate test execution.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Apr 22, 2024
1 parent 2968f1f commit 4e98a26
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 24 deletions.
1 change: 1 addition & 0 deletions tests/benchmarks/run/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ def test_runner():
benchmark_factory=SineRegular,
operator_factory=DeepNeuralOperator,
max_epochs=2,
batch_size=128,
)
BenchmarkRunner.run(config)
32 changes: 17 additions & 15 deletions tests/benchmarks/test_navierstokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ def test_navierstokes_shapes_and_plot():
assert y.shape == (3, 64, 64, 10)
assert v.shape == (1, 64, 64, 10)

fig, axs = plt.subplots(1, 2, subplot_kw={"projection": "3d"}, figsize=(10, 5))
x, u, y, v = benchmark.test_dataset[0]
axs[0].scatter(x[2], x[0], x[1], s=1, c=u, cmap="jet", alpha=0.7)
axs[1].scatter(y[2], y[0], y[1], s=1, c=v, cmap="jet", alpha=0.7)
for i in range(2):
axs[i].set_xlabel("t")
axs[i].set_ylabel("x")
axs[i].set_zlabel("y")
axs[0].set_title("Input")
axs[1].set_title("Output")

try:
fig.savefig("docs/benchmarks/img/navierstokes.png", dpi=500)
except FileNotFoundError:
pass
plot = False
if plot:
fig, axs = plt.subplots(1, 2, subplot_kw={"projection": "3d"}, figsize=(10, 5))
x, u, y, v = benchmark.test_dataset[0]
axs[0].scatter(x[2], x[0], x[1], s=1, c=u, cmap="jet", alpha=0.7)
axs[1].scatter(y[2], y[0], y[1], s=1, c=v, cmap="jet", alpha=0.7)
for i in range(2):
axs[i].set_xlabel("t")
axs[i].set_ylabel("x")
axs[i].set_zlabel("y")
axs[0].set_title("Input")
axs[1].set_title("Output")

try:
fig.savefig("docs/benchmarks/img/navierstokes.png", dpi=500)
except FileNotFoundError:
pass
6 changes: 4 additions & 2 deletions tests/pde/test_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import torch
import deepxde as dde
import matplotlib.pyplot as plt
Expand All @@ -7,6 +8,7 @@
torch.manual_seed(0)


@pytest.mark.slow
def test_pideeponet():
"""Physics-informed DeepONet for Poisson equation in 1D.
Example from DeepXDE in *continuiti*.
Expand Down Expand Up @@ -35,7 +37,7 @@ def v_boundary(y):
return torch.zeros_like(y)

# Sample domain and boundary points
num_domain = 100
num_domain = 32
num_boundary = 2
x_domain = geom.uniform_points(num_domain)
x_bnd = geom.uniform_boundary_points(num_boundary)
Expand All @@ -48,7 +50,7 @@ def v_boundary(y):
degree = 3
space = dde.data.PowerSeries(N=degree + 1)

num_functions = 100
num_functions = 10
coeffs = space.random(num_functions)
fx = space.eval_batch(coeffs, x)

Expand Down
14 changes: 7 additions & 7 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


def train():
dataset = SineBenchmark().train_dataset
dataset = SineBenchmark(n_train=32).train_dataset
operator = DeepONet(dataset.shapes, trunk_depth=16)

Trainer(operator).fit(dataset, tol=1e-3)
Trainer(operator).fit(dataset, tol=1e-2)

# Make sure we can use operator output on cpu again
x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
v_pred = operator(x, u, y)
assert ((v_pred - v.to("cpu")) ** 2).mean() < 1e-3
assert ((v_pred - v.to("cpu")) ** 2).mean() < 1e-2


@pytest.mark.slow
Expand All @@ -42,7 +42,7 @@ def f(x):
input_size=1,
output_size=1,
width=32,
depth=3,
depth=8,
)

# Define loss function (in continuiti style)
Expand All @@ -56,14 +56,14 @@ def loss_fn(op, x, y):
trainer = Trainer(model, loss_fn=loss_fn)
logs = trainer.fit(
train_dataset,
tol=1e-3,
tol=1e-2,
test_dataset=test_dataset,
)

# Test the model
assert logs.loss_test < 1e-3
assert logs.loss_test < 1e-2


# Use ./run_parallel.sh to run test with CUDA
if __name__ == "__main__":
train()
test_trainer_with_torch_model()

0 comments on commit 4e98a26

Please sign in to comment.