Skip to content

Commit

Permalink
adding porting code and fix layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Oct 8, 2024
1 parent 9162e4d commit cbdb372
Show file tree
Hide file tree
Showing 2 changed files with 375 additions and 118 deletions.
6 changes: 6 additions & 0 deletions jflux/modules/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(
self.img_norm1 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand All @@ -229,6 +230,7 @@ def __init__(
self.img_norm2 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -257,6 +259,7 @@ def __init__(
self.txt_norm1 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand All @@ -272,6 +275,7 @@ def __init__(
self.txt_norm2 = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -382,6 +386,7 @@ def __init__(
self.pre_norm = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down Expand Up @@ -419,6 +424,7 @@ def __init__(
self.norm_final = nnx.LayerNorm(
num_features=hidden_size,
use_scale=False,
use_bias=False,
epsilon=1e-6,
rngs=rngs,
param_dtype=param_dtype,
Expand Down
Loading

0 comments on commit cbdb372

Please sign in to comment.