Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588281168
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Dec 6, 2023
1 parent de2197c commit acd987c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
21 changes: 15 additions & 6 deletions swirl_dynamics/lib/diffusion/unets.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ class LatLonConv(nn.Module):
"""

features: int
kernel_size: tuple[int, int] = (3, 3)
kernel_size: tuple[int, ...] | list[int] = (3, 3)
use_bias: bool = True
kernel_init: Initializer = default_init(1.0)
strides: tuple[int, ...] | list[int] = (1, 1)
use_local: bool = False

@nn.compact
def __call__(self, x: Array) -> Array:
Expand All @@ -153,11 +155,13 @@ def __call__(self, x: Array) -> Array:
x_per, ((0, 0), (lon_pad, lon_pad), (0, 0), (0, 0)), mode="edge"
)
# shape: (batch_size, lon, lat, features)
return nn.Conv(
conv_fn = nn.ConvLocal if self.use_local else nn.Conv

return conv_fn(
self.features,
kernel_size=self.kernel_size,
use_bias=self.use_bias,
strides=(1, 1),
strides=self.strides,
kernel_init=self.kernel_init,
padding="VALID",
)(x_per)
Expand Down Expand Up @@ -189,9 +193,14 @@ def __call__(self, x: Array) -> Array:
return x


def conv_layer(padding: str, **kwargs) -> nn.Module:
if padding.lower() == "latlon":
return LatLonConv(**kwargs)
def conv_layer(
padding: str | int, use_local: bool = False, **kwargs
) -> nn.Module:
"""Wrapper for conv layers with non-standard boundary conditions."""
if isinstance(padding, str) and padding.lower() == "latlon":
return LatLonConv(use_local=use_local, **kwargs)
elif use_local:
return nn.ConvLocal(padding=padding, **kwargs)
else:
return nn.Conv(padding=padding, **kwargs)

Expand Down
22 changes: 22 additions & 0 deletions swirl_dynamics/lib/diffusion/unets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ def test_channelwise_conditioning_output_shape(self, x_dims, c_dims):
)
self.assertEqual(out.shape, x.shape)

@parameterized.parameters(((8, 8),), ((4, 8),))
def test_latlon_conv_layer_output_shape_and_equivariance(self, spatial_dims):
batch, channels = 2, 1
x = np.random.randn(batch, *spatial_dims, channels)

model = unets.conv_layer(
features=1,
kernel_size=(3, 3),
padding="LATLON",
use_bias=False,
)
out, variables = model.init_with_output(
jax.random.PRNGKey(42), x
)
x_roll = np.roll(x, shift=3, axis=2)
out_roll = model.apply(variables, x_roll)

self.assertEqual(out.shape, x.shape)
self.assertEqual(out_roll.shape, x_roll.shape)
self.assertEqual(
jnp.roll(out, shift=3, axis=2).tolist(), out_roll.tolist()
)

if __name__ == "__main__":
absltest.main()

0 comments on commit acd987c

Please sign in to comment.