Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588674014
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Dec 7, 2023
1 parent 7d886f4 commit 26c2b2b
Show file tree
Hide file tree
Showing 3 changed files with 637 additions and 5 deletions.
10 changes: 5 additions & 5 deletions swirl_dynamics/lib/diffusion/unets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from swirl_dynamics.lib.diffusion import unets
from swirl_dynamics.lib import diffusion # gpylint: disable=g-importing-member


class NetworksTest(parameterized.TestCase):
Expand All @@ -34,7 +34,7 @@ def test_unet_output_shape(self, spatial_dims, padding, ds_ratio):
batch, channels = 2, 3
x = np.random.randn(batch, *spatial_dims, channels)
sigma = np.linspace(0, 1, batch)
model = unets.UNet(
model = diffusion.unets.UNet(
out_channels=channels,
num_channels=(4, 8, 12),
downsample_ratio=ds_ratio,
Expand All @@ -53,7 +53,7 @@ def test_preconditioned_denoiser_output_shape(self, spatial_dims):
batch, channels = 2, 3
x = np.random.randn(batch, *spatial_dims, channels)
sigma = np.linspace(0, 1, batch)
model = unets.PreconditionedDenoiser(
model = diffusion.unets.PreconditionedDenoiser(
out_channels=channels,
num_channels=(4, 8, 12),
downsample_ratio=(2, 2, 2),
Expand All @@ -80,7 +80,7 @@ def test_channelwise_conditioning_output_shape(self, x_dims, c_dims):
x = jax.random.normal(jax.random.PRNGKey(42), x_dims)
cond = {"channel:cond1": jax.random.normal(jax.random.PRNGKey(42), c_dims)}
sigma = jnp.array(0.5)
model = unets.PreconditionedDenoiser(
model = diffusion.unets.PreconditionedDenoiser(
out_channels=x_dims[-1],
num_channels=(4, 8, 12),
downsample_ratio=(2, 2, 2),
Expand All @@ -106,7 +106,7 @@ def test_latlon_conv_layer_output_shape_and_equivariance(self, spatial_dims):
batch, channels = 2, 1
x = np.random.randn(batch, *spatial_dims, channels)

model = unets.conv_layer(
model = diffusion.unets.conv_layer(
features=1,
kernel_size=(3, 3),
padding="LATLON",
Expand Down
Loading

0 comments on commit 26c2b2b

Please sign in to comment.