Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561841074
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Sep 1, 2023
1 parent 8386eee commit 74ef759
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 9 deletions.
96 changes: 88 additions & 8 deletions swirl_dynamics/lib/networks/fno.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

# ********************
# Layers
Expand Down Expand Up @@ -72,6 +73,8 @@ class SpectralConv(nn.Module):
out_channels: The number of output channels.
num_modes: The number of modes in the kernel, at most (Nx // 2 + 1); length
must agree with the number of spatial dimensions.
domain_size: The spatial dimensions of the inverse fft transform; if not
specified, the input spatial dimensions are used.
use_bias: Whether to add a bias after the spectral convolution.
fft_norm: Arg to `jnp.fft.rfftn` and `jnp.fft.irfftn`; choose from
`backward`, `ortho` or `forward`.
Expand All @@ -85,6 +88,7 @@ class SpectralConv(nn.Module):
in_channels: int
out_channels: int
num_modes: tuple[int, ...]
domain_size: tuple[int, ...] | None = None
use_bias: bool = True
fft_norm: Literal["backward", "ortho", "forward"] = "backward"
contract_fn: ContractFnType = ContractFnType.DENSE
Expand Down Expand Up @@ -137,6 +141,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
f"Input `in_channels` ({in_channels}) is inconsistent with the"
f" `self.in_channels` declaration ({self.in_channels})."
)
domain_size = self.domain_size or spatial_dims

fft_axes = tuple(range(1, x.ndim - 1))
x = jnp.fft.rfftn(x, axes=fft_axes, norm=self.fft_norm)
Expand All @@ -158,7 +163,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.fft.fftshift(x[mode_slices], axes=fft_axes[:-1])

x = self._contract_fn(x, self.weights, separable=self.separable)
x = jnp.fft.irfftn(x, s=spatial_dims, axes=fft_axes, norm=self.fft_norm)
x = jnp.fft.irfftn(x, s=domain_size, axes=fft_axes, norm=self.fft_norm)

if self.use_bias:
x = x + self.bias[None]
Expand All @@ -177,19 +182,19 @@ class FnoResBlock(nn.Module):
The dimension is indicated by `len(num_modes)`.
Attributes:
out_channels: the number of output channels.
num_modes: the number of modes used for spectral conv layers; see
out_channels: The number of output channels.
num_modes: The number of modes used for spectral conv layers; see
class::`SpectralConv.num_modes`.
num_layers: the number of spectral conv layers in the block.
num_layers: The number of spectral conv layers in the block.
fft_norm: arg to `jnp.fft.rfftn` and `jnp.fft.irfftn` in spectral conv
layers.
contract_fn: the type of contraction function to be used for spectral conv;
contract_fn: The type of contraction function to be used for spectral conv;
see class::`SpectralConv.contract_fn`.
separable: whether to use a separable contraction function.
separable: Whether to use a separable contraction function.
act_fn: the activation function.
skip_type: the type of skip connection to use - choose from ["linear",
skip_type: The type of skip connection to use - choose from ["linear",
"soft-gate", "identity"].
param_dtype: the dtype of model parameters.
param_dtype: The dtype of model parameters.
"""

out_channels: int
Expand Down Expand Up @@ -308,3 +313,78 @@ def __call__(self, x: jax.Array) -> jax.Array:
x = self.act_fn(x)
x = nn.Dense(features=self.out_channels)(x)
return x


class Fno2d(nn.Module):
"""2-dimensional FNO network.
This network structure and default configs follow
https://github.com/neuraloperator/markov_neural_operator/blob/main/models/fno_2d.py
Attributes:
out_channels: The number of output channels.
num_modes: The base number of modes for the spectral conv layers (scaled
with depth); see class::`SpectralConv.num_modes`.
width: The base number of features in the intermediate layers (scaled with
depth).
num_layers: The number of spectral conv layers in the block.
domain_size: Arg to the spectral conv layers; see
class::`SpectralConv.domain_size`.
fft_norm: Arg to `jnp.fft.rfftn` and `jnp.fft.irfftn` in spectral conv
layers.
act_fn: The activation function.
param_dtype: The dtype of model parameters.
"""

out_channels: int
num_modes: tuple[int, int] = (20, 20)
width: int = 128
domain_size: tuple[int, int] | None = None
fft_norm: Literal["backward", "ortho", "forward"] = "ortho"
act_fn: Callable[[jax.Array], jax.Array] = jax.nn.selu
param_dtype: jnp.dtype = jnp.complex64
grid_dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, x: jax.Array) -> jax.Array:
batch_sz, *grid_size, _ = x.shape
grid = self.get_grid(tuple(grid_size), dtype=self.grid_dtype)
grid = jnp.tile(grid, (batch_sz,) + (1,) * (len(grid_size) + 1))

# Scaling follows the reference repo in the class description
widths = np.asarray([2, 3, 4, 4, 5], dtype=np.int32) * self.width // 4
modes = np.outer(np.asarray([4, 3, 2, 2]), np.asarray(self.num_modes)) // 4
kernel_sz = (1,) * len(grid_size)

x = jnp.concatenate([x, grid], axis=-1)
x = nn.Dense(widths[0])(x)

for i in range(4):
h1 = SpectralConv(
in_channels=x.shape[-1],
out_channels=widths[i + 1],
domain_size=self.domain_size,
num_modes=tuple(modes[i]),
fft_norm=self.fft_norm,
)(x)
h2 = nn.Conv(features=widths[i + 1], kernel_size=kernel_sz)(x)
x = h1 + h2
if i < 4:
x = self.act_fn(x)

x = self.act_fn(nn.Dense(features=widths[-1] * 2)(x))
x = self.act_fn(nn.Dense(features=widths[-1] * 2)(x))
x = nn.Dense(features=self.out_channels)(x)
return x

def get_grid(self, grid_size: tuple[int, int], dtype: jnp.dtype) -> jax.Array:
sz_x, sz_y = grid_size
grid_x = jnp.expand_dims(
jnp.linspace(0, 1, sz_x, endpoint=False, dtype=dtype), (1, 2)
)
grid_x = jnp.tile(grid_x, (1, sz_y, 1))
grid_y = jnp.expand_dims(
jnp.linspace(0, 1, sz_y, endpoint=False, dtype=dtype), (0, 2)
)
grid_y = jnp.tile(grid_y, (sz_x, 1, 1))
return jnp.concatenate([grid_x, grid_y], axis=-1)
22 changes: 21 additions & 1 deletion swirl_dynamics/lib/networks/fno_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class FnoTest(parameterized.TestCase):
n_dim=(1, 2, 3),
num_modes=(2, 3),
spatial_dim=(8, 9),
domain_size=(None, 16),
contract_fn=(fno.ContractFnType.DENSE,),
separable=(True, False),
weights_dtype=(jnp.complex64,),
Expand All @@ -34,6 +35,7 @@ def test_spectral_conv(
n_dim,
num_modes,
spatial_dim,
domain_size,
contract_fn,
separable,
weights_dtype,
Expand All @@ -42,17 +44,19 @@ def test_spectral_conv(
in_channels, out_channels = 5, 5
input_shape = (batch_sz,) + (spatial_dim,) * n_dim + (in_channels,)
inputs = jax.random.normal(jax.random.PRNGKey(0), input_shape)
domain_size = domain_size or spatial_dim
layer = fno.SpectralConv(
in_channels=in_channels,
out_channels=out_channels,
num_modes=(num_modes,) * n_dim,
domain_size=(domain_size,) * n_dim,
contract_fn=contract_fn,
separable=separable,
weights_dtype=weights_dtype,
)
layer_vars = layer.init(jax.random.PRNGKey(0), inputs)
out = jax.jit(layer.apply)(layer_vars, inputs)
out_shape = (batch_sz,) + (spatial_dim,) * n_dim + (out_channels,)
out_shape = (batch_sz,) + (domain_size,) * n_dim + (out_channels,)
self.assertEqual(out.shape, out_shape)

@parameterized.product(
Expand Down Expand Up @@ -103,6 +107,22 @@ def test_fno(self, n_dim):
out_shape = (batch_sz,) + (spatial_dim,) * n_dim + (out_channels,)
self.assertEqual(out.shape, out_shape)

def test_fno_2d(self):
batch_sz = 2
in_channels, out_channels = 1, 1
num_modes = (4, 4)
width = 5
spatial_dim = 12
input_shape = (batch_sz,) + (spatial_dim,) * 2 + (in_channels,)
inputs = jax.random.normal(jax.random.PRNGKey(0), input_shape)
model = fno.Fno2d(
out_channels=out_channels, num_modes=num_modes, width=width
)
model_vars = model.init(jax.random.PRNGKey(0), inputs)
out = jax.jit(model.apply)(model_vars, inputs)
out_shape = (batch_sz,) + (spatial_dim,) * 2 + (out_channels,)
self.assertEqual(out.shape, out_shape)


if __name__ == "__main__":
absltest.main()

0 comments on commit 74ef759

Please sign in to comment.