Commit b5cdd35: fix port
ariG23498 committed Oct 8, 2024
1 parent bce2003 commit b5cdd35
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions jflux/port.py
@@ -268,8 +268,8 @@ def port_autoencoder(autoencoder, tensors):


 def port_linear(linear, tensors, prefix):
-    linear.kernel.value = rearrange(tensors[f"{prefix}.weight"])
-    linear.bias.value = rearrange(tensors[f"{prefix}.bias"])
+    linear.kernel.value = rearrange(tensors[f"{prefix}.weight"], "i o -> o i")
+    linear.bias.value = tensors[f"{prefix}.bias"]
     return linear
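
The hunk above fixes how a PyTorch Linear's parameters map onto the Flax NNX Linear being filled in: the weight matrix must be transposed into kernel layout, and the 1-D bias is copied as-is (the old code passed it through rearrange without a pattern, which einops does not accept). A minimal sketch with NumPy stand-ins, not taken from the repo, of why the "i o -> o i" rearrange is needed:

    # PyTorch nn.Linear stores `weight` as (out_features, in_features), while a
    # Flax/NNX Linear `kernel` is laid out as (in_features, out_features).
    import numpy as np
    from einops import rearrange

    out_features, in_features = 3, 5
    torch_weight = np.random.randn(out_features, in_features)  # PyTorch layout
    x = np.random.randn(in_features)

    # The einops pattern simply swaps the two axes, whatever the labels are named.
    kernel = rearrange(torch_weight, "i o -> o i")
    assert kernel.shape == (in_features, out_features)

    # PyTorch computes y = x @ weight.T; Flax computes y = x @ kernel.
    assert np.allclose(x @ kernel, x @ torch_weight.T)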


@@ -333,7 +333,7 @@ def port_double_stream_block(double_stream_block, tensors, prefix):
     double_stream_block.img_attn = port_self_attention(
         self_attention=double_stream_block.img_attn,
         tensors=tensors,
-        prefix="{prefix}.img_attn",
+        prefix=f"{prefix}.img_attn",
     )

     # double_stream_block.img_norm2 has no params
@@ -360,7 +360,7 @@ def port_double_stream_block(double_stream_block, tensors, prefix):
     double_stream_block.txt_attn = port_self_attention(
         self_attention=double_stream_block.txt_attn,
         tensors=tensors,
-        prefix="{prefix}.txt_attn",
+        prefix=f"{prefix}.txt_attn",
     )

     # double_stream_block.txt_norm2 has no params
@@ -381,22 +381,22 @@ def port_double_stream_block(double_stream_block, tensors, prefix):

 def port_single_stream_block(single_stream_block, tensors, prefix):
     single_stream_block.linear1 = port_linear(
-        linear=single_stream_block.linear1, tensors=tensors, prefix="{prefix}.linear1"
+        linear=single_stream_block.linear1, tensors=tensors, prefix=f"{prefix}.linear1"
     )
     single_stream_block.linear2 = port_linear(
-        linear=single_stream_block.linear2, tensors=tensors, prefix="{prefix}.linear2"
+        linear=single_stream_block.linear2, tensors=tensors, prefix=f"{prefix}.linear2"
     )

     single_stream_block.norm = port_qk_norm(
-        qk_norm=single_stream_block.norm, tensors=tensors, prefix="{prefix}.norm"
+        qk_norm=single_stream_block.norm, tensors=tensors, prefix=f"{prefix}.norm"
     )

     # single_stream_block.pre_norm has no params

     single_stream_block.modulation = port_modulation(
         modulation=single_stream_block.modulation,
         tensors=tensors,
-        prefix="{prefix}.modulation",
+        prefix=f"{prefix}.modulation",
     )

     return single_stream_block
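
The prefix changes in the three hunks above all fix the same bug: the strings were plain literals, so "{prefix}" was never substituted and the tensor lookups could not match any checkpoint keys. A tiny sketch of the difference, using a hypothetical prefix value that is not from the repo:

    prefix = "double_blocks.0"

    plain = "{prefix}.img_attn"    # literal text; {prefix} is never substituted
    fixed = f"{prefix}.img_attn"   # f-string; the variable is interpolated

    print(plain)   # {prefix}.img_attn  -> never matches a key in `tensors`
    print(fixed)   # double_blocks.0.img_attn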
@@ -413,21 +413,21 @@ def port_mlp_embedder(mlp_embedder, tensors, prefix):
     return mlp_embedder


-def port_last_layer(last_layer, tensors, prefix):
+def port_final_layer(final_layer, tensors, prefix):
     # last_layer.norm_final has no params
-    last_layer.linear = port_linear(
-        linear=last_layer.linear,
+    final_layer.linear = port_linear(
+        linear=final_layer.linear,
         tensors=tensors,
         prefix=f"{prefix}.linear",
     )

-    last_layer.adaLN_modulation.layers[1] = port_linear(
-        linear=last_layer.adaLN_modulation.layers[1],
+    final_layer.adaLN_modulation.layers[1] = port_linear(
+        linear=final_layer.adaLN_modulation.layers[1],
         tensors=tensors,
         prefix=f"{prefix}.adaLN_modulation.1",
     )

-    return last_layer
+    return final_layer


 def port_flux(flux, tensors):
@@ -476,8 +476,8 @@ def port_flux(flux, tensors):
             prefix=f"single_blocks.{i}",
         )

-    flux.last_layer = port_last_layer(
-        last_layer=flux.last_layer,
+    flux.final_layer = port_final_layer(
+        final_layer=flux.final_layer,
         tensors=tensors,
-        prefix="last_layer",
+        prefix="final_layer",
     )
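
The final hunk renames the helper and, more importantly, changes the lookup prefix from "last_layer" to "final_layer" so the lookups line up with how the checkpoint names these weights. When in doubt, the keys a prefix selects can be checked directly against the safetensors file; a small sketch with a hypothetical file name, not part of port.py:

    from safetensors import safe_open

    # List every checkpoint key under the prefix used by port_final_layer.
    with safe_open("flux1-dev.safetensors", framework="numpy") as f:
        final_keys = [k for k in f.keys() if k.startswith("final_layer.")]

    print(final_keys)  # expected to include final_layer.linear.weight, final_layer.linear.bias, ...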
