chore: initial fix
ariG23498 committed Oct 7, 2024
1 parent 55abdee · commit d91da33
Showing 1 changed file with 41 additions and 62 deletions.
jflux/model.py (103 changes: 41 additions & 62 deletions)
@@ -1,18 +1,19 @@
 from dataclasses import dataclass
 
-import jax.dtypes
+import jax
+import jax.numpy as jnp
 from chex import Array
 from einops import rearrange
 from flax import nnx
-from jax import numpy as jnp
+from jax.typing import DTypeLike
 
-from jflux.modules.layers import (
-    AdaLayerNorm,
-    Embed,
-    Identity,
+from flux.modules.layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
     timestep_embedding,
 )
-from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
-from jax.typing import DTypeLike
 
 
 @dataclass
@@ -29,6 +30,13 @@ class FluxParams:
     theta: int
     qkv_bias: bool
     guidance_embed: bool
+    rngs: nnx.Rngs
+    param_dtype: DTypeLike
+
+
+class Identity(nnx.Module):
+    def __call__(self, x: Array) -> Array:
+        return x
 
 
 class Flux(nnx.Module):
@@ -37,7 +45,6 @@ class Flux(nnx.Module):
"""

def __init__(self, params: FluxParams):
# Note: no super().__init__() call for nnx
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
Expand All @@ -52,87 +59,59 @@ def __init__(self, params: FluxParams):
             )
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
-        self.pe_embedder = Embed(
+        self.pe_embedder = EmbedND(
             dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
         )
         self.img_in = nnx.Linear(
-            self.in_channels,
-            self.hidden_size,
+            in_features=self.in_channels,
+            out_features=self.hidden_size,
+            use_bias=True,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-            rngs=rngs,
-        )
-        self.time_in = MLPEmbedder(
-            in_dim=256,
-            hidden_dim=self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
-        self.vector_in = MLPEmbedder(
-            params.vec_in_dim,
-            self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
+            rngs=params.rngs,
+            param_dtype=params.param_dtype,
         )
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
         self.guidance_in = (
-            MLPEmbedder(
-                in_dim=256,
-                hidden_dim=self.hidden_size,
-                rngs=rngs,
-                dtype=self.dtype,
-                param_dtype=self.param_dtype,
-            )
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
             if params.guidance_embed
-            else Identity()
+            else Identity(rngs=params.rngs, param_dtype=params.param_dtype)
         )
         self.txt_in = nnx.Linear(
-            params.context_in_dim,
-            self.hidden_size,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
+            in_features=params.context_in_dim,
+            out_features=self.hidden_size,
+            use_bias=True,
+            rngs=params.rngs,
+            param_dtype=params.param_dtype,
         )
 
-        self.double_blocks = nnx.Sequential(
+        self.double_blocks = nnx.ModuleList(
             *[
                 DoubleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
-                    rngs=rngs,
-                    dtype=self.dtype,
-                    param_dtype=self.param_dtype,
+                    rngs=params.rngs,
+                    param_dtype=params.param_dtype,
                 )
                 for _ in range(params.depth)
             ]
         )
 
-        self.single_blocks = nnx.Sequential(
+        self.single_blocks = nnx.ModuleList(
             *[
                 SingleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
-                    rngs=rngs,
-                    dtype=self.dtype,
-                    param_dtype=self.param_dtype,
+                    rngs=params.rngs,
+                    param_dtype=params.param_dtype,
                 )
                 for _ in range(params.depth_single_blocks)
            ]
         )
 
-        self.final_layer = AdaLayerNorm(
-            self.hidden_size,
-            1,
-            self.out_channels,
-            rngs=rngs,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
     def __call__(
         self,
@@ -155,17 +134,17 @@ def __call__(
             raise ValueError(
                 "Didn't get guidance strength for guidance distilled model."
             )
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))  # type: ignore
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
         vec = vec + self.vector_in(y)
         txt = self.txt_in(txt)
 
-        ids = jnp.concat((txt_ids, img_ids), axis=1)
+        ids = jnp.concatenate((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
         for block in self.double_blocks:
             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
 
-        img = jnp.concat((txt, img), axis=1)
+        img = jnp.concatenate((txt, img), 1)
         for block in self.single_blocks:
             img = block(img, vec=vec, pe=pe)
         img = img[:, txt.shape[1] :, ...]
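For reference, a minimal construction sketch of the module as it stands after this commit. This is a sketch under assumptions, not part of the commit: the dimension values are hypothetical, illustrative ones (chosen so that hidden_size / num_heads equals sum(axes_dim), which the EmbedND positional-embedding setup implies), and it assumes the new `from flux.modules.layers import ...` line resolves on the import path.

import jax.numpy as jnp
from flax import nnx

from jflux.model import Flux, FluxParams

# Hypothetical config -- illustrative values only, not an official Flux preset.
params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=1024,         # 1024 / 8 heads = 128 = sum(axes_dim)
    mlp_ratio=4.0,
    num_heads=8,
    depth=2,                  # kept small for illustration
    depth_single_blocks=4,    # kept small for illustration
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,
    rngs=nnx.Rngs(0),         # field added in this commit
    param_dtype=jnp.float32,  # field added in this commit
)
model = Flux(params)

One caveat when trying this: `jax.numpy.concatenate` takes `axis=` rather than `dim=`, so the two concatenation lines added in `__call__` above would need that spelling before a forward pass will execute.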
