From d91da3334f4fd04ce355b69d03504935a2bba620 Mon Sep 17 00:00:00 2001
From: aritG23498
Date: Mon, 7 Oct 2024 14:13:33 +0000
Subject: [PATCH 1/2] chore: initial fix

---
 jflux/model.py | 103 ++++++++++++++++++++-----------------------------
 1 file changed, 41 insertions(+), 62 deletions(-)

diff --git a/jflux/model.py b/jflux/model.py
index bd02bac..1317c94 100644
--- a/jflux/model.py
+++ b/jflux/model.py
@@ -1,18 +1,19 @@
 from dataclasses import dataclass

-import jax.dtypes
+import jax
+import jax.numpy as jnp
 from chex import Array
+from einops import rearrange
 from flax import nnx
-from jax import numpy as jnp
-from jax.typing import DTypeLike
-
-from jflux.modules.layers import (
-    AdaLayerNorm,
-    Embed,
-    Identity,
+from flux.modules.layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
     timestep_embedding,
 )
-from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
+from jax.typing import DTypeLike


 @dataclass
@@ -29,6 +30,13 @@ class FluxParams:
     theta: int
     qkv_bias: bool
     guidance_embed: bool
+    rngs: nnx.Rngs
+    param_dtype: DTypeLike
+
+
+class Identity(nnx.Module):
+    def __call__(self, x: Array) -> Array:
+        return x


 class Flux(nnx.Module):
@@ -37,7 +45,6 @@ class Flux(nnx.Module):
     """

     def __init__(self, params: FluxParams):
-        # Note: no super().__init__() call for nnx
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
@@ -52,87 +59,59 @@ def __init__(self, params: FluxParams):
         )
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
-        self.pe_embedder = Embed(
+        self.pe_embedder = EmbedND(
             dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
         )
         self.img_in = nnx.Linear(
-            self.in_channels,
-            self.hidden_size,
+            in_features=self.in_channels,
+            out_features=self.hidden_size,
             use_bias=True,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-            rngs=rngs,
-        )
-        self.time_in = MLPEmbedder(
-            in_dim=256,
-            hidden_dim=self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
-        self.vector_in = MLPEmbedder(
-            params.vec_in_dim,
-            self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
+            rngs=params.rngs,
+            param_dtype=params.param_dtype,
         )
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
         self.guidance_in = (
-            MLPEmbedder(
-                in_dim=256,
-                hidden_dim=self.hidden_size,
-                rngs=rngs,
-                dtype=self.dtype,
-                param_dtype=self.param_dtype,
-            )
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
             if params.guidance_embed
-            else Identity()
+            else Identity(rngs=params.rngs, param_dtype=params.param_dtype)
         )
         self.txt_in = nnx.Linear(
-            params.context_in_dim,
-            self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
+            in_features=params.context_in_dim,
+            out_features=self.hidden_size,
+            use_bias=True,
+            rngs=params.rngs,
+            param_dtype=params.param_dtype,
         )

-        self.double_blocks = nnx.Sequential(
+        self.double_blocks = nnx.ModuleList(
             *[
                 DoubleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
-                    rngs=rngs,
-                    dtype=self.dtype,
-                    param_dtype=self.param_dtype,
+                    rngs=params.rngs,
+                    param_dtype=params.param_dtype,
                 )
                 for _ in range(params.depth)
             ]
         )

-        self.single_blocks = nnx.Sequential(
+        self.single_blocks = nnx.ModuleList(
             *[
                 SingleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
-                    rngs=rngs,
-                    dtype=self.dtype,
-                    param_dtype=self.param_dtype,
+                    rngs=params.rngs,
+                    param_dtype=params.param_dtype,
                 )
                 for _ in range(params.depth_single_blocks)
             ]
         )

-        self.final_layer = AdaLayerNorm(
-            self.hidden_size,
-            1,
-            self.out_channels,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

     def __call__(
         self,
@@ -155,17 +134,17 @@ def __call__(
             raise ValueError(
                 "Didn't get guidance strength for guidance distilled model."
             )
-        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))  # type: ignore
+        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
         vec = vec + self.vector_in(y)
         txt = self.txt_in(txt)

-        ids = jnp.concat((txt_ids, img_ids), axis=1)
+        ids = jnp.concatenate((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

         for block in self.double_blocks:
             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

-        img = jnp.concat((txt, img), axis=1)
+        img = jnp.concatenate((txt, img), 1)
         for block in self.single_blocks:
             img = block(img, vec=vec, pe=pe)
         img = img[:, txt.shape[1] :, ...]

From 7b98927c4820a4efab8b3b3ccd22d75299666e31 Mon Sep 17 00:00:00 2001
From: aritG23498
Date: Mon, 7 Oct 2024 18:20:09 +0000
Subject: [PATCH 2/2] chore: adding more tests

---
 jflux/model.py        |   8 +-
 jflux/modules.py      | 390 ------------------------------------------
 tests/test_model.py   |  40 +++++
 tests/test_modules.py |  82 ---------
 4 files changed, 43 insertions(+), 477 deletions(-)
 delete mode 100644 jflux/modules.py
 create mode 100644 tests/test_model.py
 delete mode 100644 tests/test_modules.py

diff --git a/jflux/model.py b/jflux/model.py
index 1317c94..acc6f32 100644
--- a/jflux/model.py
+++ b/jflux/model.py
@@ -1,9 +1,7 @@
 from dataclasses import dataclass

-import jax
 import jax.numpy as jnp
 from chex import Array
-from einops import rearrange
 from flax import nnx
 from flux.modules.layers import (
     DoubleStreamBlock,
@@ -74,7 +72,7 @@ def __init__(self, params: FluxParams):
         self.guidance_in = (
             MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
             if params.guidance_embed
-            else Identity(rngs=params.rngs, param_dtype=params.param_dtype)
+            else Identity()
         )
         self.txt_in = nnx.Linear(
             in_features=params.context_in_dim,
@@ -84,7 +82,7 @@ def __init__(self, params: FluxParams):
             param_dtype=params.param_dtype,
         )

-        self.double_blocks = nnx.ModuleList(
+        self.double_blocks = nnx.Sequential(
             *[
                 DoubleStreamBlock(
                     self.hidden_size,
@@ -98,7 +96,7 @@ def __init__(self, params: FluxParams):
             ]
         )

-        self.single_blocks = nnx.ModuleList(
+        self.single_blocks = nnx.Sequential(
             *[
                 SingleStreamBlock(
                     self.hidden_size,
diff --git a/jflux/modules.py b/jflux/modules.py
deleted file mode 100644
index f4be9ad..0000000
--- a/jflux/modules.py
+++ /dev/null
@@ -1,390 +0,0 @@
-import typing
-from dataclasses import dataclass
-
-import jax
-import jax.numpy as jnp
-from chex import Array
-from einops import rearrange
-from flax import nnx
-from jax.typing import DTypeLike
-
-from jflux.layers import QKNorm
-from jflux.math import attention
-
-
-class MLPEmbedder(nnx.Module):
-    """
-    MLP embedder with a single hidden layer and SiLU activation.
-
-    Args:
-        in_dim (int): Input dimension.
-        hidden_dim (int): Hidden dimension.
-        rngs (nnx.Rngs): RNGs for the layer.
-        dtype (DTypeLike): Data type for the layer.
-        param_dtype (DTypeLike): Parameter data type for the layer.
-    """
-
-    def __init__(
-        self,
-        in_dim: int,
-        hidden_dim: int,
-        rngs: nnx.Rngs,
-        dtype: DTypeLike = jax.dtypes.bfloat16,
-        param_dtype: DTypeLike = None,
-    ) -> None:
-        if param_dtype is None:
-            param_dtype = dtype
-
-        self.in_layer = nnx.Linear(
-            in_features=in_dim,
-            out_features=hidden_dim,
-            dtype=dtype,
-            param_dtype=param_dtype,
-            use_bias=True,
-            rngs=rngs,
-        )
-        self.out_layer = nnx.Linear(
-            in_features=hidden_dim,
-            out_features=hidden_dim,
-            dtype=dtype,
-            param_dtype=param_dtype,
-            use_bias=True,
-            rngs=rngs,
-        )
-
-    def __call__(self, x: Array) -> Array:
-        return self.out_layer(nnx.silu(self.in_layer(x)))
-
-
-class SelfAttention(nnx.Module):
-    """
-    Self-attention module with QKV linear layers and a projection layer.
-
-    Args:
-        dim (int): Dimension of the input.
-        rngs (nnx.Rngs): RNGs for the layer.
-        num_heads (int): Number of attention heads.
-        qkv_bias (bool): Whether to use bias in QKV linear layers.
-        dtype (DTypeLike): Data type for the layer.
-        param_dtype (DTypeLike): Parameter data type for the layer.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        rngs: nnx.Rngs,
-        num_heads: int = 8,
-        qkv_bias: bool = False,
-        dtype: DTypeLike = jax.dtypes.bfloat16,
-        param_dtype: DTypeLike = None,
-    ) -> None:
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.qkv = nnx.Linear(
-            in_features=dim,
-            out_features=dim * 3,
-            use_bias=qkv_bias,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-        self.norm = QKNorm(head_dim, rngs=rngs, dtype=dtype, param_dtype=param_dtype)
-        self.proj = nnx.Linear(
-            in_features=dim,
-            out_features=dim,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-
-    def __call__(self, x: Array, pe: Array) -> Array:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
-
-
-# TODO (SauravMaheshkar): use `chex.dataclass`
-@dataclass
-class ModulationOut:
-    shift: Array
-    scale: Array
-    gate: Array
-
-
-class Modulation(nnx.Module):
-    """
-    Modulation module with a linear layer and split output.
-
-    Args:
-        dim (int): Dimension of the input.
-        double (bool): Whether to split the output into two parts.
-        rngs (nnx.Rngs): RNGs for the layer.
-        dtype (DTypeLike): Data type for the layer.
-        param_dtype (DTypeLike): Parameter data type for the layer.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        double: bool,
-        rngs: nnx.Rngs,
-        dtype: DTypeLike = jax.dtypes.bfloat16,
-        param_dtype: DTypeLike = None,
-    ) -> None:
-        if param_dtype is None:
-            param_dtype = dtype
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = nnx.Linear(
-            dim,
-            self.multiplier * dim,
-            use_bias=True,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-
-    def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]:
-        ary = self.lin(nnx.silu(vec))[:, None, :]
-        out = jnp.split(ary, self.multiplier, axis=-1)
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
-
-
-class DoubleStreamBlock(nnx.Module):
-    """
-    Custom Module for a DoubleStreamBlock.
-
-    Args:
-        hidden_size (int): Dimension of the hidden layer.
-        num_heads (int): Number of attention heads.
-        mlp_ratio (float): Ratio of hidden layer to mlp hidden layer.
-        rngs (nnx.Rngs): RNGs for the layer.
-        qkv_bias (bool): Whether to use bias in QKV linear layers.
-        dtype (DTypeLike): Data type for the layer.
-        param_dtype (DTypeLike): Parameter data type for the layer.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float,
-        rngs: nnx.Rngs,
-        qkv_bias: bool = False,
-        dtype: DTypeLike = jax.dtypes.bfloat16,
-        param_dtype: DTypeLike = None,
-    ):
-        if param_dtype is None:
-            param_dtype = dtype
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_mod = Modulation(
-            hidden_size, double=True, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.img_norm1 = nnx.LayerNorm(
-            hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.img_attn = SelfAttention(
-            dim=hidden_size,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-
-        self.img_norm2 = nnx.LayerNorm(
-            hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.img_mlp = nnx.Sequential(
-            nnx.Linear(
-                hidden_size,
-                mlp_hidden_dim,
-                use_bias=True,
-                rngs=rngs,
-                dtype=dtype,
-                param_dtype=param_dtype,
-            ),
-            nnx.gelu,
-            nnx.Linear(
-                mlp_hidden_dim,
-                hidden_size,
-                use_bias=True,
-                rngs=rngs,
-                dtype=dtype,
-                param_dtype=param_dtype,
-            ),
-        )
-
-        self.txt_mod = Modulation(
-            hidden_size, double=True, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.txt_norm1 = nnx.LayerNorm(
-            hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.txt_attn = SelfAttention(
-            dim=hidden_size,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-
-        self.txt_norm2 = nnx.LayerNorm(
-            hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-        self.txt_mlp = nnx.Sequential(
-            nnx.Linear(
-                hidden_size,
-                mlp_hidden_dim,
-                use_bias=True,
-                rngs=rngs,
-                dtype=dtype,
-                param_dtype=param_dtype,
-            ),
-            nnx.gelu,
-            nnx.Linear(
-                mlp_hidden_dim,
-                hidden_size,
-                use_bias=True,
-                rngs=rngs,
-                dtype=dtype,
-                param_dtype=param_dtype,
-            ),
-        )
-
-    @typing.no_type_check
-    def __call__(
-        self, img: Array, txt: Array, vec: Array, pe: Array
-    ) -> tuple[Array, Array]:
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
-
-        # prepare image for attention
-        img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(
-            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
-        )
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(
-            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
-        )
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        q = jnp.concat((txt_q, img_q), axis=2)
-        k = jnp.concat((txt_k, img_k), axis=2)
-        v = jnp.concat((txt_v, img_v), axis=2)
-
-        attn = attention(q, k, v, pe=pe)
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp(
-            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
-        )
-
-        # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp(
-            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
-        )
-        return img, txt
-
-
-class SingleStreamBlock(nnx.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-
-    Args:
-        hidden_size (int): Dimension of the hidden layer.
-        num_heads (int): Number of attention heads.
-        rngs (nnx.Rngs): RNGs for the layer.
-        mlp_ratio (float): Ratio of hidden layer to mlp hidden layer.
-        qk_scale (float): Scaling factor for query and key.
-        dtype (DTypeLike): Data type for the layer.
-        param_dtype (DTypeLike): Parameter data type for the layer.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        rngs: nnx.Rngs,
-        mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
-        dtype: DTypeLike = jax.dtypes.bfloat16,
-        param_dtype: DTypeLike = None,
-    ):
-        if param_dtype is None:
-            param_dtype = dtype
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = nnx.Linear(
-            hidden_size,
-            hidden_size * 3 + self.mlp_hidden_dim,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-        # proj and mlp_out
-        self.linear2 = nnx.Linear(
-            hidden_size + self.mlp_hidden_dim,
-            hidden_size,
-            rngs=rngs,
-            dtype=dtype,
-            param_dtype=param_dtype,
-        )
-
-        self.norm = QKNorm(head_dim, rngs=rngs, dtype=dtype, param_dtype=param_dtype)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = nnx.LayerNorm(
-            hidden_size, epsilon=1e-6, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-
-        self.mlp_act = nnx.gelu
-        self.modulation = Modulation(
-            hidden_size, double=False, rngs=rngs, dtype=dtype, param_dtype=param_dtype
-        )
-
-    def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
-        mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = jnp.split(
-            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], axis=-1
-        )
-
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(jnp.concat((attn, self.mlp_act(mlp)), axis=2))
-        return x + mod.gate * output
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..2823cba
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,40 @@
+import numpy as np
+from jax import numpy as jnp
+from flax import nnx
+from jflux.model import FluxParams, Flux
+
+class ModelTestCase(np.testing.TestCase):
+    def test_model(self):
+        # Initialize
+        in_channels=64
+        vec_in_dim=768
+        context_in_dim=4096
+        hidden_size=3072
+        mlp_ratio=4.0
+        num_heads=24
+        depth=19
+        depth_single_blocks=38
+        axes_dim=[16, 56, 56]
+        theta=10_000
+        qkv_bias=True
+        guidance_embed=False
+        rngs=nnx.Rngs(default=42)
+        param_dtype=jnp.float32
+
+        flux_params = FluxParams(
+            in_channels=in_channels,
+            vec_in_dim=vec_in_dim,
+            context_in_dim=context_in_dim,
+            hidden_size=hidden_size,
+            mlp_ratio=mlp_ratio,
+            num_heads=num_heads,
+            depth=depth,
+            depth_single_blocks=depth_single_blocks,
+            axes_dim=axes_dim,
+            theta=theta,
+            qkv_bias=qkv_bias,
+            guidance_embed=guidance_embed,
+            rngs=rngs,
+            param_dtype=param_dtype,
+        )
+        flux = Flux(params=flux_params)
\ No newline at end of file
diff --git a/tests/test_modules.py b/tests/test_modules.py
deleted file mode 100644
index e2b5878..0000000
--- a/tests/test_modules.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import chex
-import jax.numpy as jnp
-import pytest
-import torch
-from flax import nnx
-from flux.modules.layers import MLPEmbedder
-from flux.modules.layers import Modulation as PytorchModulation
-from flux.modules.layers import SelfAttention as PytorchSelfAttention
-
-from jflux.modules import MLPEmbedder as JaxMLPEmbedder
-from jflux.modules import Modulation as JaxModulation
-from jflux.modules import SelfAttention as JaxSelfAttention
-from tests.utils import torch2jax
-
-
-class ModulesTestCase(chex.TestCase):
-    def test_mlp_embedder(self):
-        # Initialize layers
-        pytorch_mlp_embedder = MLPEmbedder(in_dim=512, hidden_dim=256)
-        jax_mlp_embedder = JaxMLPEmbedder(
-            in_dim=512, hidden_dim=256, rngs=nnx.Rngs(default=42), dtype=jnp.float32
-        )
-
-        # Generate random inputs
-        torch_input = torch.randn(1, 32, 512, dtype=torch.float32)
-        jax_input = torch2jax(torch_input)
-
-        # Forward pass
-        jax_output = jax_mlp_embedder(jax_input)
-        pytorch_output = pytorch_mlp_embedder(torch_input)
-
-        # Assertions
-        chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)])
-
-    @pytest.mark.skip(reason="Blocked by apply_rope")
-    def test_self_attention(self):
-        # Initialize layers
-        pytorch_self_attention = PytorchSelfAttention(dim=512)
-        jax_self_attention = JaxSelfAttention(
-            dim=512, rngs=nnx.Rngs(default=42), dtype=jnp.float32
-        )
-
-        # Generate random inputs
-        torch_input = torch.randn(1, 32, 512, dtype=torch.float32)
-        torch_pe = torch.randn(1, 32, 512, dtype=torch.float32)
-        jax_input = torch2jax(torch_input)
-        jax_pe = torch2jax(torch_pe)
-
-        # Forward pass
-        jax_output = jax_self_attention(jax_input, jax_pe)
-        pytorch_output = pytorch_self_attention(torch_input, torch_pe)
-
-        # Assertions
-        chex.assert_equal_shape([jax_output, torch2jax(pytorch_output)])
-
-    def test_modulation(self):
-        # Initialize layers
-        pytorch_modulation = PytorchModulation(dim=512, double=True)
-        jax_modulation = JaxModulation(
-            dim=512, double=True, rngs=nnx.Rngs(default=42), dtype=jnp.float32
-        )
-
-        # Generate random inputs
-        torch_input = torch.randn(1, 32, 512, dtype=torch.float32)
-        jax_input = torch2jax(torch_input)
-
-        # Forward pass
-        jax_output = jax_modulation(jax_input)
-        pytorch_output = pytorch_modulation(torch_input)
-
-        # Convert Modulation output to individual tensors
-        jax_tensors = [jax_output[0].shift, jax_output[0].scale, jax_output[0].gate]
-        torch_tensors = [
-            torch2jax(pytorch_output[0].shift),
-            torch2jax(pytorch_output[0].scale),
-            torch2jax(pytorch_output[0].gate),
-        ]
-
-        # Assertions
-        assert len(jax_output) == len(pytorch_output)
-        for i in range(len(jax_output)):
-            chex.assert_equal_shape([jax_tensors[i], torch_tensors[i]])
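
A note on exercising this revision end to end: the test added in PATCH 2/2 only constructs `Flux`; it never calls it. Below is a minimal sketch, not part of the patch, of how a forward-pass smoke test could look. It assumes the `__call__` signature mirrors the upstream flux model (`img, img_ids, txt, txt_ids, timesteps, y, guidance`), uses a deliberately small configuration chosen so that `hidden_size // num_heads == sum(axes_dim)` (as the full-size config in tests/test_model.py does: 3072 // 24 == 128 == 16 + 56 + 56), and makes no claim that this revision runs cleanly given the `flux.modules.layers` imports it introduces.

    import jax.numpy as jnp
    from flax import nnx
    from jflux.model import Flux, FluxParams

    # Small config: head_dim = 128 // 4 = 32 = sum(axes_dim), kept consistent on purpose.
    params = FluxParams(
        in_channels=16,
        vec_in_dim=32,
        context_in_dim=48,
        hidden_size=128,
        mlp_ratio=4.0,
        num_heads=4,
        depth=1,
        depth_single_blocks=1,
        axes_dim=[8, 12, 12],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
        rngs=nnx.Rngs(default=42),
        param_dtype=jnp.float32,
    )
    flux = Flux(params=params)

    # Dummy token sequences: batch of 1, 4 image tokens, 2 text tokens.
    img = jnp.ones((1, 4, 16))        # (batch, img_seq, in_channels)
    img_ids = jnp.zeros((1, 4, 3))    # positional ids, one per axes_dim entry
    txt = jnp.ones((1, 2, 48))        # (batch, txt_seq, context_in_dim)
    txt_ids = jnp.zeros((1, 2, 3))
    timesteps = jnp.ones((1,))
    y = jnp.ones((1, 32))             # (batch, vec_in_dim)

    # Argument names here are assumed from the upstream flux API, not shown in the diff.
    out = flux(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, timesteps=timesteps, y=y)
    assert out.shape == (1, 4, 16)    # out_channels == in_channels in this revision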