Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the model #9

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 37 additions & 55 deletions jflux/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from dataclasses import dataclass

import jax.dtypes
import jax.numpy as jnp
from chex import Array
from flax import nnx
from jax import numpy as jnp
from flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from jax.typing import DTypeLike

from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
from jflux.modules.layers import AdaLayerNorm, Embed, Identity, timestep_embedding


@dataclass
class FluxParams:
Expand All @@ -24,6 +28,13 @@ class FluxParams:
theta: int
qkv_bias: bool
guidance_embed: bool
rngs: nnx.Rngs
param_dtype: DTypeLike


class Identity(nnx.Module):
def __call__(self, x: Array) -> Array:
return x


class Flux(nnx.Module):
Expand All @@ -32,7 +43,6 @@ class Flux(nnx.Module):
"""

def __init__(self, params: FluxParams):
# Note: no super().__init__() call for nnx
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
Expand All @@ -47,48 +57,29 @@ def __init__(self, params: FluxParams):
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = Embed(
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nnx.Linear(
self.in_channels,
self.hidden_size,
in_features=self.in_channels,
out_features=self.hidden_size,
use_bias=True,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=rngs,
)
self.time_in = MLPEmbedder(
in_dim=256,
hidden_dim=self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.vector_in = MLPEmbedder(
params.vec_in_dim,
self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(
in_dim=256,
hidden_dim=self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if params.guidance_embed
else Identity()
)
self.txt_in = nnx.Linear(
params.context_in_dim,
self.hidden_size,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
in_features=params.context_in_dim,
out_features=self.hidden_size,
use_bias=True,
rngs=params.rngs,
param_dtype=params.param_dtype,
)

self.double_blocks = nnx.Sequential(
Expand All @@ -98,9 +89,8 @@ def __init__(self, params: FluxParams):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
for _ in range(params.depth)
]
Expand All @@ -112,22 +102,14 @@ def __init__(self, params: FluxParams):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
rngs=params.rngs,
param_dtype=params.param_dtype,
)
for _ in range(params.depth_single_blocks)
]
)

self.final_layer = AdaLayerNorm(
self.hidden_size,
1,
self.out_channels,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

def __call__(
self,
Expand All @@ -150,17 +132,17 @@ def __call__(
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) # type: ignore
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)

ids = jnp.concat((txt_ids, img_ids), axis=1)
ids = jnp.concatenate((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

img = jnp.concat((txt, img), axis=1)
img = jnp.concatenate((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
Expand Down
Loading
Loading