diff --git a/tests/pde/test_deepxde.py b/tests/pde/deepxde_example.py similarity index 98% rename from tests/pde/test_deepxde.py rename to tests/pde/deepxde_example.py index b349dc94..fd0179d3 100644 --- a/tests/pde/test_deepxde.py +++ b/tests/pde/deepxde_example.py @@ -6,7 +6,7 @@ torch.manual_seed(0) -def test_deepxde(): +def deepxde_example(): """Physics-informed DeepONet for Poisson equation in 1D. Example from DeepXDE. https://deepxde.readthedocs.io/en/latest/demos/operator/poisson.1d.pideeponet.html @@ -96,4 +96,4 @@ def boundary(_, on_boundary): if __name__ == "__main__": - test_deepxde() + deepxde_example() diff --git a/tests/pde/test_pideeponet.py b/tests/pde/test_pideeponet.py index 7b90e2ba..c1104ad1 100644 --- a/tests/pde/test_pideeponet.py +++ b/tests/pde/test_pideeponet.py @@ -77,11 +77,9 @@ def v_boundary(y): # operator = cti.operators.DeepNeuralOperator(dataset.shapes) # Define and train model - trainer = cti.Trainer( - operator, - loss_fn=cti.pde.PhysicsInformedLoss(equation), - ) - trainer.fit(dataset) + loss_fn = cti.pde.PhysicsInformedLoss(equation) + trainer = cti.Trainer(operator, loss_fn=loss_fn) + trainer.fit(dataset, epochs=100) # Plot realizations of f(x) n = 3 @@ -115,7 +113,7 @@ def v_boundary(y): plt.plot(y, v[i].T, "-") plt.xlabel("x") - plt.show() + plt.savefig("pideeponet.png", dpi=500) if __name__ == "__main__":