Skip to content

Commit

Permalink
Wrap FNO from NVIDIA Modulus.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Jul 12, 2024
1 parent f07fe22 commit 0a87a1c
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/continuiti/operators/modulus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
`continuiti.operators.modulus`
Operators from NVIDIA Modulus wrapped in continuiti.
"""

# Test if we can import NVIDIA Modulus
try:
import modulus # noqa: F40
except ImportError:
raise ImportError("NVIDIA Modulus not found!")

from .fno import FNO

__all__ = [
"FNO",
]
61 changes: 61 additions & 0 deletions src/continuiti/operators/modulus/fno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
`continuiti.operators.modulus.fno`
The Fourier Neural Operator from NVIDIA Modulus wrapped in continuiti.
"""

import torch
from typing import Optional
from continuiti.operators import Operator, OperatorShapes
from modulus.models.fno import FNO as FNOModulus


class FNO(Operator):
r"""FNO architecture from NVIDIA Modulus.
The `in_channels` and `out_channels` arguments are determined by the
`shapes` argument. The `dimension` is set to the dimension of the input
coordinates, assuming that the grid dimension is the same as the coordinate
dimension of `x`.
All other keyword arguments are passed to the Fourier Neural Operator, please refer
to the documentation of the `modulus.model.fno.FNO` class for more information.
Args:
shapes: Shapes of the input and output data.
device: Device.
**kwargs: Additional arguments for the Fourier layers.
"""

def __init__(
self,
shapes: OperatorShapes,
device: Optional[torch.device] = None,
dimension: Optional[int] = None,
**kwargs,
):
super().__init__(shapes, device)

if dimension is None:
# Per default, use coordinate dimension
dimension = shapes.x.dim

self.fno = FNOModulus(
in_channels=shapes.u.dim,
out_channels=shapes.v.dim,
dimension=dimension,
**kwargs,
)
self.fno.to(device)

def forward(
self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
r"""Forward pass of the Fourier Neural Operator.
Args:
x: Ignored.
u: Input function values of shape (batch_size, u_dim, num_sensors...).
y: Ignored.
"""
return self.fno(u)
Empty file.
53 changes: 53 additions & 0 deletions tests/operators/modulus/test_modulus_fno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from continuiti.benchmarks.sine import SineBenchmark
from continuiti.trainer import Trainer
from continuiti.operators.modulus import FNO
from continuiti.operators.losses import MSELoss


@pytest.mark.slow
def test_modulus_fno():
try:
import modulus # noqa: F401
except ImportError:
pytest.skip("NVIDIA Modulus not found!")

# Data set
benchmark = SineBenchmark(n_train=1)
dataset = benchmark.train_dataset

# Operator
# Configured like the default continuiti `FourierNeuralOperator`
# with depth=3 and width=3 as in `test_fno.py`.
operator = FNO(
dataset.shapes,
decoder_layers=1,
decoder_layer_size=1,
decoder_activation_fn="identity",
num_fno_layers=3, # "depth" in FourierNeuralOperator
latent_channels=3, # "width" in FourierNeuralOperator
num_fno_modes=dataset.shapes.u.size[0] // 2 + 1,
padding=0,
coord_features=False,
)

# Train
Trainer(operator, device="cpu").fit(dataset, tol=1e-12, epochs=10_000)

# Check solution
x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
assert MSELoss()(operator, x, u, y, v) < 1e-12


# SineBenchmark(n_train=1024, n_sensors=128, n_evaluations=128), epochs=100

# NVIDIA Modulus FNO
# Parameters: 3560 Device: cpu
# Epoch 100/100 Step 32/32 [====================] 6ms/step [0:19min<0:00min] - loss/train = 6.3876e-05

# continuiti FNO
# Parameters: 3556 Device: cpu
# Epoch 100/100 Step 32/32 [====================] 3ms/step [0:10min<0:00min] - loss/train = 1.4440e-04

# -> continuiti FNO is 2x faster than NVIDIA Modulus FNO
# -> NVIDIA Modulus FNO can not handle different number of sensors and evaluations

0 comments on commit 0a87a1c

Please sign in to comment.