Fix the model (#9)
Co-authored-by: Saurav Maheshkar <[email protected]>
ariG23498 and SauravMaheshkar authored Oct 7, 2024
1 parent e43da22 commit e9dbdf7
Showing 3 changed files with 77 additions and 445 deletions.
92 changes: 37 additions & 55 deletions jflux/model.py
@@ -1,14 +1,18 @@
from dataclasses import dataclass

import jax.dtypes
import jax.numpy as jnp
from chex import Array
from flax import nnx
from jax import numpy as jnp
from flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from jax.typing import DTypeLike

from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
from jflux.modules.layers import AdaLayerNorm, Embed, Identity, timestep_embedding


@dataclass
class FluxParams:
@@ -24,6 +28,13 @@ class FluxParams:
theta: int
qkv_bias: bool
guidance_embed: bool
rngs: nnx.Rngs
param_dtype: DTypeLike


class Identity(nnx.Module):
def __call__(self, x: Array) -> Array:
return x


class Flux(nnx.Module):
@@ -32,7 +43,6 @@ class Flux(nnx.Module):
"""

def __init__(self, params: FluxParams):
# Note: no super().__init__() call for nnx
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
@@ -47,48 +57,29 @@ def __init__(self, params: FluxParams):
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = Embed(
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nnx.Linear(
self.in_channels,
self.hidden_size,
in_features=self.in_channels,
out_features=self.hidden_size,
use_bias=True,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=rngs,
)
self.time_in = MLPEmbedder(
in_dim=256,
hidden_dim=self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.vector_in = MLPEmbedder(
params.vec_in_dim,
self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(
in_dim=256,
hidden_dim=self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if params.guidance_embed
else Identity()
)
self.txt_in = nnx.Linear(
params.context_in_dim,
self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
in_features=params.context_in_dim,
out_features=self.hidden_size,
use_bias=True,
rngs=params.rngs,
param_dtype=params.param_dtype,
)

self.double_blocks = nnx.Sequential(
@@ -98,9 +89,8 @@ def __init__(self, params: FluxParams):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
for _ in range(params.depth)
]
@@ -112,22 +102,14 @@ def __init__(self, params: FluxParams):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
for _ in range(params.depth_single_blocks)
]
)

self.final_layer = AdaLayerNorm(
self.hidden_size,
1,
self.out_channels,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

def __call__(
self,
@@ -150,17 +132,17 @@ def __call__(
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) # type: ignore
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)

ids = jnp.concat((txt_ids, img_ids), axis=1)
ids = jnp.concatenate((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

img = jnp.concat((txt, img), axis=1)
img = jnp.concatenate((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
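
For context only (not part of the commit): a minimal sketch of how the refactored model might be constructed and called after this change. It uses only names visible in this diff (the new `rngs` and `param_dtype` fields on `FluxParams`, `EmbedND`, `LastLayer`, the `Identity` fallback for `guidance_in`) plus upstream-FLUX assumptions for anything collapsed here, namely the toy dimensions, the `timesteps` argument, and the `guidance` default, so treat it as illustrative rather than as the commit's own test code.

```python
# Hypothetical usage sketch (not in the commit): build FluxParams with the new
# rngs/param_dtype fields and run one toy forward pass.
import jax.numpy as jnp
from flax import nnx

from jflux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=1024,        # toy value; assumed divisible by num_heads
    mlp_ratio=4.0,
    num_heads=8,
    depth=2,                 # toy depth instead of the full block counts
    depth_single_blocks=2,
    axes_dim=[16, 56, 56],   # sums to hidden_size // num_heads = 128
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,    # guidance_in falls back to the new Identity module
    rngs=nnx.Rngs(0),        # replaces the per-layer rngs argument removed above
    param_dtype=jnp.bfloat16,
)
model = Flux(params)

# Toy inputs matching the argument names used in __call__ above; the full
# signature is truncated in this diff, so names like `timesteps` are assumed
# from upstream FLUX.
img = jnp.ones((1, 256, 64))      # packed latent patches
img_ids = jnp.zeros((1, 256, 3))
txt = jnp.ones((1, 77, 4096))     # text-encoder hidden states
txt_ids = jnp.zeros((1, 77, 3))
y = jnp.ones((1, 768))            # pooled text vector
timesteps = jnp.ones((1,))

out = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
            y=y, timesteps=timesteps, guidance=None)
print(out.shape)                  # expected (1, 256, 64)
```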