
Commit b4757fd

Code update
PiperOrigin-RevId: 616880701
ilopezgp authored and The swirl_dynamics Authors committed Mar 18, 2024
1 parent 4ae652a commit b4757fd
Showing 5 changed files with 117 additions and 11 deletions.
7 changes: 5 additions & 2 deletions swirl_dynamics/lib/diffusion/__init__.py
@@ -47,11 +47,14 @@
    exponential_noise_decay,
    uniform_time,
)
from swirl_dynamics.lib.diffusion.unets import (
    AxialMLPInterpConvMerge,
    InterpConvMerge,
    UNet,
)
from swirl_dynamics.lib.diffusion.unets import PreconditionedDenoiser as PreconditionedDenoiserUNet
from swirl_dynamics.lib.diffusion.unets import UNet
from swirl_dynamics.lib.diffusion.unets3d import PreconditionedDenoiser3d as PreconditionedDenoiserUNet3d
from swirl_dynamics.lib.diffusion.unets3d import UNet3d
from swirl_dynamics.lib.diffusion.vivit import ViViT
from swirl_dynamics.lib.diffusion.vivit_diffusion import PreconditionedDenoiser as PreconditionedDenoiserViViT
from swirl_dynamics.lib.diffusion.vivit_diffusion import ViViTDiffusion

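With the reshuffled exports above, both merging modules become part of the diffusion package's public surface. A minimal sketch of the resulting import path (the instantiation is illustrative and uses only the shared MergeChannelCond fields introduced in unets.py below):

from swirl_dynamics.lib import diffusion

# Either merging module can be constructed from the shared base fields.
merge = diffusion.AxialMLPInterpConvMerge(embed_dim=128, kernel_size=(3, 3))
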
75 changes: 73 additions & 2 deletions swirl_dynamics/lib/diffusion/unets.py
@@ -231,14 +231,48 @@ def position_embedding(ndim: int, **kwargs) -> nn.Module:
    raise ValueError("Only 1D or 2D position embeddings are supported.")


class Axial2DMLP(nn.Module):
  """Applies axial spatial perceptrons to an input tensor of 2D fields.

  The inputs are assumed to have the shape (..., *spatial_dims, channels),
  with two spatial dimensions.

  Attributes:
    out_dims: A tuple containing the output spatial dimensions.
    act_fn: The activation function.
  """

  out_dims: tuple[int, int]
  act_fn: Callable[[Array], Array] = nn.swish

  @nn.compact
  def __call__(self, x: Array) -> Array:

    num_channels = x.shape[-1]

    for i, out_dim in enumerate(self.out_dims):
      spatial_dim = -3 + i
      x = jnp.swapaxes(x, spatial_dim, -1)
      x = nn.Dense(features=out_dim)(x)
      x = nn.swish(x)
      x = jnp.swapaxes(x, -1, spatial_dim)

    x = nn.Dense(features=num_channels)(x)
    return x


class MergeChannelCond(nn.Module):
  """Merges conditional inputs along the channel dimension."""
  """Base class for merging conditional inputs along the channel dimension."""

  embed_dim: int
  kernel_size: Sequence[int]
  resize_method: str = "cubic"
  padding: str = "CIRCULAR"


class InterpConvMerge(MergeChannelCond):
  """Merges conditional inputs through interpolation and convolutions."""

  @nn.compact
  def __call__(self, x: Array, cond: dict[str, Array]):
    """Merges conditional inputs along the channel dimension.
@@ -284,6 +318,42 @@ def __call__(self, x: Array, cond: dict[str, Array]):
    return x


class AxialMLPInterpConvMerge(MergeChannelCond):
  """Merges conditional inputs through MLPs, interpolation and convolutions."""

  @nn.compact
  def __call__(self, x: Array, cond: dict[str, Array]):
    """Transforms and merges conditional inputs along the channel dimension.

    Relevant fields in the conditional input dictionary are first passed
    through axial perceptrons, then resized, and finally concatenated with
    the main input along their last axes.

    Args:
      x: The main model input.
      cond: A dictionary of conditional inputs. Those with keys that start
        with "channel:" are processed here while all others are omitted.

    Returns:
      Model input merged with channel conditions.
    """

    out_spatial_shape = x.shape[-3:-1]
    proc_cond = {}
    for key, value in cond.items():
      if value.shape[-3:-1] == out_spatial_shape:
        continue
      proc_cond[key] = Axial2DMLP(out_dims=value.shape[-3:-1])(value)

    merge_channel_cond = InterpConvMerge(
        embed_dim=self.embed_dim,
        kernel_size=self.kernel_size,
        resize_method=self.resize_method,
        padding=self.padding,
    )
    return merge_channel_cond(x, proc_cond)


class DStack(nn.Module):
  """Downsampling stack.
@@ -470,6 +540,7 @@ class UNet(nn.Module):
  num_heads: int = 8
  cond_resize_method: str = "bilinear"
  cond_embed_dim: int = 128
  cond_merging_fn: type[MergeChannelCond] = InterpConvMerge

  @nn.compact
  def __call__(
@@ -514,7 +585,7 @@ def __call__(

    kernel_dim = x.ndim - 2
    cond = {} if cond is None else cond
    x = MergeChannelCond(
    x = self.cond_merging_fn(
        embed_dim=self.cond_embed_dim,
        resize_method=self.cond_resize_method,
        kernel_size=(3,) * kernel_dim,
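As a minimal sketch of what the new Axial2DMLP does (shapes and PRNG seed here are hypothetical, not from the commit): each spatial axis is swapped into the trailing position, passed through its own Dense layer, and swapped back, so out_dims controls the output spatial shape while a final Dense restores the channel count.

import jax
import jax.numpy as jnp

from swirl_dynamics.lib.diffusion import unets

mlp = unets.Axial2DMLP(out_dims=(8, 4))
value = jnp.ones((2, 16, 10, 6))  # (batch, height, width, channels)
out, _ = mlp.init_with_output(jax.random.PRNGKey(0), value)
assert out.shape == (2, 8, 4, 6)  # height 16 -> 8, width 10 -> 4, channels kept
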
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/unets3d.py
@@ -379,7 +379,7 @@ def __call__(
    )(x)

    cond = {} if cond is None else cond
    x = unets.MergeChannelCond(
    x = unets.InterpConvMerge(
        embed_dim=self.cond_embed_dim,
        resize_method=self.cond_resize_method,
        kernel_size=(3, 3),
42 changes: 37 additions & 5 deletions swirl_dynamics/lib/diffusion/unets_test.py
@@ -123,20 +123,52 @@ def test_channelwise_conditioning_output_shape(self, x_dims, c_dims):
    )
    # Check shape dict so that err message is easier to read when things break.
    shape_dict = jax.tree_map(jnp.shape, variables["params"])
    self.assertIn("MergeChannelCond_0", shape_dict)
    self.assertIn("InterpConvMerge_0", shape_dict)
    # First condition should be reshaped. Second one (correct shape) should not.
    self.assertIn(
        "conv2d_embed_channel:cond1", shape_dict["MergeChannelCond_0"]
    )
    self.assertIn("conv2d_embed_channel:cond1", shape_dict["InterpConvMerge_0"])
    self.assertNotIn(
        "conv2d_embed_channel:cond2", shape_dict["MergeChannelCond_0"]
        "conv2d_embed_channel:cond2", shape_dict["InterpConvMerge_0"]
    )

    out = jax.jit(functools.partial(model.apply, is_training=True))(
        variables, x, sigma, cond
    )
    self.assertEqual(out.shape, x.shape)

  def test_preconditioned_merging_functions(self):
    x_dims = (1, 16, 8, 3)
    c_dims = (1, 8, 4, 6)
    x = jax.random.normal(jax.random.PRNGKey(42), x_dims)
    cond = {
        "channel:cond1": jax.random.normal(jax.random.PRNGKey(42), c_dims),
    }
    sigma = jnp.array(0.5)
    model = diffusion.unets.PreconditionedDenoiser(
        out_channels=x_dims[-1],
        num_channels=(4, 8, 12),
        downsample_ratio=(2, 2, 2),
        num_blocks=2,
        num_heads=4,
        sigma_data=1.0,
        use_position_encoding=False,
        cond_embed_dim=32,
        cond_resize_method="cubic",
        cond_merging_fn=diffusion.unets.AxialMLPInterpConvMerge,
    )
    variables = model.init(
        jax.random.PRNGKey(42), x=x, sigma=sigma, cond=cond, is_training=True
    )
    # Check shape dict so that err message is easier to read when things break.
    shape_dict = jax.tree_map(jnp.shape, variables["params"])
    self.assertIn("AxialMLPInterpConvMerge_0", shape_dict)
    self.assertIn("Axial2DMLP_0", shape_dict["AxialMLPInterpConvMerge_0"])
    self.assertIn("InterpConvMerge_0", shape_dict["AxialMLPInterpConvMerge_0"])

    out = jax.jit(functools.partial(model.apply, is_training=True))(
        variables, x, sigma, cond
    )
    self.assertEqual(out.shape, x.shape)


if __name__ == "__main__":
  absltest.main()
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/vivit_diffusion.py
@@ -409,7 +409,7 @@ def __call__(
    batch_size_input, num_frames, height, width, _ = x.shape

    cond = {} if cond is None else cond
    x = unets.MergeChannelCond(
    x = unets.InterpConvMerge(
        embed_dim=self.cond_embed_dim,
        resize_method=self.cond_resize_method,
        kernel_size=self.cond_kernel_size,
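Taken together, the commit turns the channel-conditioning merge into a configurable constructor field: UNet defaults to InterpConvMerge (the renamed previous behavior), while the 3D and ViViT denoisers keep that module hard-wired. A hedged sketch of opting into the axial-MLP variant on the 2D UNet, with field values borrowed from the new test above (other UNet fields are assumed to keep their defaults):

from swirl_dynamics.lib.diffusion import unets

model = unets.UNet(
    out_channels=3,
    num_channels=(4, 8, 12),
    downsample_ratio=(2, 2, 2),
    num_blocks=2,
    num_heads=4,
    use_position_encoding=False,
    cond_embed_dim=32,
    cond_resize_method="cubic",
    cond_merging_fn=unets.AxialMLPInterpConvMerge,  # default: InterpConvMerge
)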
