Skip to content

Commit

Permalink
Adding FLUX porting code (#11)
Browse files Browse the repository at this point in the history
Co-authored-by: Saurav Maheshkar <[email protected]>
  • Loading branch information
ariG23498 and SauravMaheshkar authored Oct 9, 2024
1 parent 9162e4d commit 0c9fb04
Show file tree
Hide file tree
Showing 8 changed files with 400 additions and 146 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @SauravMaheshkar
* @SauravMaheshkar @ariG23498
5 changes: 2 additions & 3 deletions jflux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

import jax
import jax.numpy as jnp
from flax import nnx
from einops import rearrange
from fire import Fire
from flax import nnx
from jax.typing import DTypeLike

from PIL import Image

from einops import rearrange
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5

Expand Down
4 changes: 2 additions & 2 deletions jflux/modules/conditioner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Note: This is a torch module not a Jax module
from torch import nn
from chex import Array
import jax.numpy as jnp
from chex import Array
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


Expand Down
6 changes: 6 additions & 0 deletions jflux/modules/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(
self.img_norm1 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand All @@ -229,6 +230,7 @@ def __init__(
self.img_norm2 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -257,6 +259,7 @@ def __init__(
self.txt_norm1 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand All @@ -272,6 +275,7 @@ def __init__(
self.txt_norm2 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -382,6 +386,7 @@ def __init__(
self.pre_norm = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -419,6 +424,7 @@ def __init__(
self.norm_final = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down
Loading

0 comments on commit 0c9fb04

Please sign in to comment.