From 057c93cebc3c8ccac996bc5a5b49c88e3e39c4b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?= Date: Thu, 29 Aug 2024 17:01:45 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 669111124 --- swirl_dynamics/lib/diffusion/unets.py | 159 ++++++++++++++++++ swirl_dynamics/lib/diffusion/unets_test.py | 41 +++++ .../rectified_flow/evaluation_main.py | 9 + .../rectified_flow/inference_main.py | 10 ++ .../rectified_flow/main_train_ens.py | 9 + 5 files changed, 228 insertions(+) diff --git a/swirl_dynamics/lib/diffusion/unets.py b/swirl_dynamics/lib/diffusion/unets.py index 34a0ee6..e530590 100644 --- a/swirl_dynamics/lib/diffusion/unets.py +++ b/swirl_dynamics/lib/diffusion/unets.py @@ -22,6 +22,7 @@ import numpy as np from swirl_dynamics.lib import layers + Array = jax.Array Initializer = nn.initializers.Initializer PrecisionLike = ( @@ -50,6 +51,7 @@ class AdaptiveScale(nn.Module): see e.g. https://arxiv.org/abs/2105.05233, and for the more general FiLM technique see https://arxiv.org/abs/1709.07871. """ + act_fun: Callable[[Array], Array] = nn.swish precision: PrecisionLike = None dtype: jnp.dtype = jnp.float32 @@ -452,6 +454,139 @@ def __call__(self, x: Array, cond: dict[str, Array]): return merge_channel_cond(x, proc_cond) +class MergeEmdCond(nn.Module): + """Base class for merging conditional inputs as embeddings.""" + + def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool): + pass + + +class EmbConvMerge(MergeEmdCond): + """Compute conditional inputs through interpolation and convolutions. + + We resize the conditional inputs to match the spatial shape of the main input + and then pass them through a nonlinearity and a ConvLayer. The output is then + mixied with the embedding from the Fourier embedding. + + Attributes: + embed_dim: The output channel dimension. + latent_dim: The latent dimension of the embedding. + downsample_ratio: Ratio for the downsampling of the embedding. + interp_shape: The shape to which the conditional inputs are resized. + kernel_size: The convolutional kernel size. + resize_method: The interpolation method employed by `jax.image.resize`. + padding: The padding method of all convolutions. + num_heads: Number of heads in the attention block. + normalize_qk: Whether to normalize the query and key vectors in the + attention block. + """ + + embed_dim: int + latent_dim: int + kernel_size: Sequence[int] + downsample_ratio: Sequence[int] + interp_shape: Sequence[int] + resize_method: str = "cubic" + padding: str = "CIRCULAR" + num_heads: int = 128 + normalize_qk: bool = True + precision: PrecisionLike = None + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool): + """Merges conditional inputs along the channel dimension. + + Fields with spatial shape differing from the sample `x` are reshaped to + match it. Then, all conditional fields are passed through a nonlinearity, + a ConvLayer, and then concatenated with the main input along the last axis. + + Args: + emb: Embedding coming from the kernel_dim = x.ndim - 2. + cond: A dictionary of conditional inputs. Those with keys that start with + "channel:" are processed here while all others are omitted. + is_training: Whether the model is in training mode. + + Returns: + Embedding merged with channel conditions. + """ + + if emb.shape[-1] != self.embed_dim: + raise ValueError( + f"Number of channels in the embedding ({emb.shape[-1]}) must " + "match the number of channels in the output " + f"{self.embed_dim})." + ) + + value_temp = [] + + # Extract fields, resize and concatenate. + for key, value in sorted(cond.items()): + # TODO: Change the prefix to "merge_embed:". + if key.startswith("channel:"): + # Enforcing prefix in the key. + value = layers.FilteredResize( + output_size=self.interp_shape, + kernel_size=self.kernel_size, + method=self.resize_method, + padding=self.padding, + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, + name=f"resize_embedding_{key}", + )(value) + + value_temp.append(value) + + value = jnp.concatenate(value_temp, axis=-1) + + kernel_dim = value.ndim - 2 + # Downsample the embedding. + num_levels = len(self.downsample_ratio) + for level in range(num_levels): + value = nn.swish(nn.LayerNorm()(value)) + value = layers.DownsampleConv( + features=self.latent_dim, + ratios=(self.downsample_ratio[level],) * kernel_dim, + kernel_init=default_init(1.0), + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, + name=f"level_{level}.embedding_downsample_conv", + )(value) + + # Add a self-attention block. + b, _, _, c = value.shape + value = AttentionBlock( + num_heads=self.num_heads, + precision=self.precision, + dtype=self.dtype, + normalize_qk=self.normalize_qk, + param_dtype=self.param_dtype, + name="cond_embedding.attention_block", + )(value.reshape(b, -1, c), is_training=is_training) + + value = nn.Dense( + features=self.embed_dim, + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, + )(value.reshape(b, -1)) + value = nn.swish(value) + + # Concatenate the noise and conditional embedding. + emb = jnp.concatenate([emb, value], axis=-1) + emb = nn.Dense( + features=self.embed_dim, + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, + )(emb) + + return emb + + class DStack(nn.Module): """Downsampling stack. @@ -682,6 +817,8 @@ class UNet(nn.Module): cond_resize_method: str = "bilinear" cond_embed_dim: int = 128 cond_merging_fn: type[MergeChannelCond] = InterpConvMerge + cond_embed_fn: type[nn.Module] | None = None + cond_embed_kwargs: dict[str, jax.typing.ArrayLike] | None = None precision: PrecisionLike = None dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype = jnp.float32 @@ -743,6 +880,28 @@ def __call__( )(x, cond) emb = FourierEmbedding(dims=self.noise_embed_dim)(sigma) + # Incorporating the embedding from the conditional inputs. + if self.cond_embed_fn: + if self.cond_embed_kwargs is None: + # For backward compatibility. + # TODO: Remove this once the configs are updated. + cond_embed_kwargs = dict(latent_dim=32, num_heads=32) + else: + cond_embed_kwargs = self.cond_embed_kwargs + + emb = self.cond_embed_fn( + embed_dim=self.noise_embed_dim, + latent_dim=cond_embed_kwargs["latent_dim"], + num_heads=cond_embed_kwargs["num_heads"], + kernel_size=(3,) * kernel_dim, + interp_shape=x.shape[:-1], + downsample_ratio=self.downsample_ratio, + padding=self.padding, + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, + )(emb, cond, is_training=is_training) + skips = DStack( num_channels=self.num_channels, num_res_blocks=len(self.num_channels) * (self.num_blocks,), diff --git a/swirl_dynamics/lib/diffusion/unets_test.py b/swirl_dynamics/lib/diffusion/unets_test.py index 3833589..860acff 100644 --- a/swirl_dynamics/lib/diffusion/unets_test.py +++ b/swirl_dynamics/lib/diffusion/unets_test.py @@ -170,6 +170,47 @@ def test_preconditioned_merging_functions(self): ) self.assertEqual(out.shape, x.shape) + def test_preconditioned_embedding_functions(self): + x_dims = (1, 16, 8, 3) + c_dims = (1, 8, 4, 6) + x = jax.random.normal(jax.random.PRNGKey(42), x_dims) + cond = { + "channel:cond1": jax.random.normal(jax.random.PRNGKey(42), c_dims), + } + sigma = jnp.array(0.5) + model = diffusion.unets.PreconditionedDenoiser( + out_channels=x_dims[-1], + num_channels=(4, 8, 12), + downsample_ratio=(2, 2, 2), + num_blocks=2, + num_heads=4, + sigma_data=1.0, + use_position_encoding=False, + cond_embed_dim=32, + cond_resize_method="cubic", + cond_embed_fn=diffusion.unets.EmbConvMerge, + ) + variables = model.init( + jax.random.PRNGKey(42), x=x, sigma=sigma, cond=cond, is_training=True + ) + # Check shape dict so that err message is easier to read when things break. + shape_dict = jax.tree.map(jnp.shape, variables["params"]) + self.assertIn("EmbConvMerge_0", shape_dict) + self.assertIn( + "level_0.embedding_downsample_conv", shape_dict["EmbConvMerge_0"] + ) + self.assertIn( + "cond_embedding.attention_block", shape_dict["EmbConvMerge_0"] + ) + self.assertIn( + "resize_embedding_channel:cond1", shape_dict["EmbConvMerge_0"] + ) + + out = jax.jit(functools.partial(model.apply, is_training=True))( + variables, x, sigma, cond + ) + self.assertEqual(out.shape, x.shape) + if __name__ == "__main__": absltest.main() diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py b/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py index 32dbd5e..1818ca0 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/evaluation_main.py @@ -27,6 +27,7 @@ from ml_collections import config_flags import numpy as np from swirl_dynamics.data import hdf5_utils +from swirl_dynamics.lib.diffusion import unets from swirl_dynamics.lib.solvers import ode as ode_solvers from swirl_dynamics.projects.debiasing.rectified_flow import data_utils from swirl_dynamics.projects.debiasing.rectified_flow import evaluation_metrics as metrics @@ -349,6 +350,13 @@ def read_normalized_stats( def build_model(config): """Builds the model from config file.""" + + if "conditional_embedding" in config and config.conditional_embedding: + logging.info("Using conditional embedding") + cond_embed_fn = unets.EmbConvMerge + else: + cond_embed_fn = None + flow_model = models.RescaledUnet( out_channels=config.out_channels, num_channels=config.num_channels, @@ -362,6 +370,7 @@ def build_model(config): use_position_encoding=config.use_position_encoding, num_heads=config.num_heads, normalize_qk=config.normalize_qk, + cond_embed_fn=cond_embed_fn, ) model = models.ConditionalReFlowModel( diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py b/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py index 1bc2200..3b3e7a3 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/inference_main.py @@ -26,6 +26,7 @@ import ml_collections from ml_collections import config_flags import numpy as np +from swirl_dynamics.lib.diffusion import unets from swirl_dynamics.lib.solvers import ode as ode_solvers from swirl_dynamics.projects.debiasing.rectified_flow import data_utils from swirl_dynamics.projects.debiasing.rectified_flow import models @@ -251,6 +252,14 @@ def read_normalized_stats( def build_model(config): """Builds the model from config file.""" + + # Adding the conditional embedding for the FILM layer. + if "conditional_embedding" in config and config.conditional_embedding: + logging.info("Using conditional embedding") + cond_embed_fn = unets.EmbConvMerge + else: + cond_embed_fn = None + flow_model = models.RescaledUnet( out_channels=config.out_channels, num_channels=config.num_channels, @@ -263,6 +272,7 @@ def build_model(config): resize_to_shape=config.resize_to_shape, use_position_encoding=config.use_position_encoding, num_heads=config.num_heads, + cond_embed_fn=cond_embed_fn, normalize_qk=config.normalize_qk, ) diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py b/swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py index 71688cf..d8a2b25 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py @@ -25,6 +25,7 @@ from ml_collections import config_flags import optax from orbax import checkpoint +from swirl_dynamics.lib.diffusion import unets from swirl_dynamics.projects.debiasing.rectified_flow import data_utils from swirl_dynamics.projects.debiasing.rectified_flow import models from swirl_dynamics.projects.debiasing.rectified_flow import trainers @@ -239,6 +240,13 @@ def main(argv): dtype = jax.numpy.float32 param_dtype = jax.numpy.float32 + # Adding the conditional embedding for the FILM layer. + if "conditional_embedding" in config and config.conditional_embedding: + logging.info("Using conditional embedding") + cond_embed_fn = unets.EmbConvMerge + else: + cond_embed_fn = None + # Setting up the neural network for the flow model. flow_model = models.RescaledUnet( out_channels=config.out_channels, @@ -253,6 +261,7 @@ def main(argv): use_position_encoding=config.use_position_encoding, num_heads=config.num_heads, normalize_qk=config.normalize_qk, + cond_embed_fn=cond_embed_fn, dtype=dtype, param_dtype=param_dtype, )