fix: hdit out_channel
beniz committed Sep 25, 2024
1 parent 6a65ac1 commit 84473fc
Showing 4 changed files with 12 additions and 13 deletions. The change renames the in_channels / out_channels constructor arguments of HDiT and Img2ImgTurbo to in_channel / out_channel, and HDiT now stores out_channel as an attribute.
4 changes: 2 additions & 2 deletions models/diffusion_networks.py
@@ -226,8 +226,8 @@ def define_G(
         model = HDiT(
             levels=hdit_config.levels,
             mapping=hdit_config.mapping,
-            in_channels=in_channel,
-            out_channels=model_output_nc,
+            in_channel=in_channel,
+            out_channel=model_output_nc,
             patch_size=hdit_config.patch_size,
             num_classes=0,
             mapping_cond_dim=0,
8 changes: 4 additions & 4 deletions models/gan_networks.py
@@ -250,8 +250,8 @@ def define_G(
         net = HDiT(
             levels=hdit_config.levels,
             mapping=hdit_config.mapping,
-            in_channels=model_input_nc,
-            out_channels=model_output_nc,
+            in_channel=model_input_nc,
+            out_channel=model_output_nc,
             patch_size=hdit_config.patch_size,
             last_zero_init=False,
             num_classes=0,
@@ -262,8 +262,8 @@ def define_G(
         return net
     elif G_netG == "img2img_turbo":
         net = Img2ImgTurbo(
-            in_channels=model_input_nc,
-            out_channels=model_output_nc,
+            in_channel=model_input_nc,
+            out_channel=model_output_nc,
            lora_rank_unet=G_lora_unet,
            lora_rank_vae=G_lora_vae,
        )
11 changes: 5 additions & 6 deletions models/modules/hdit/hdit.py
@@ -741,8 +741,8 @@ def __init__(
         self,
         levels,
         mapping,
-        in_channels,
-        out_channels,
+        in_channel,
+        out_channel,
         patch_size,
         last_zero_init=True,
         num_classes=0,
@@ -752,7 +752,8 @@
     ):
         super().__init__()
         self.num_classes = num_classes
-        self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size)
+        self.out_channel = out_channel
+        self.patch_in = TokenMerge(in_channel, levels[0].width, patch_size)
         self.mapping = tag_module(
             MappingNetwork(
                 mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout
@@ -820,9 +821,7 @@ def __init__(
         )

         self.out_norm = RMSNorm(levels[0].width)
-        self.patch_out = TokenSplitWithoutSkip(
-            levels[0].width, out_channels, patch_size
-        )
+        self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channel, patch_size)
         if last_zero_init:
             nn.init.zeros_(self.patch_out.proj.weight)
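
The substantive part of this hunk, beyond the keyword rename, is that HDiT now keeps out_channel around as an attribute. A minimal self-contained sketch of that pattern follows (a toy module, not the real HDiT; all names and values are illustrative only):

import torch
from torch import nn


class ToyDenoiser(nn.Module):
    """Toy stand-in for the pattern added in this commit: the output channel
    count passed to the constructor is stored on the module so that callers
    can query it later."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Mirrors `self.out_channel = out_channel` added to HDiT.__init__.
        self.out_channel = out_channel
        self.proj = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, x):
        return self.proj(x)


net = ToyDenoiser(in_channel=3, out_channel=3)
assert net.out_channel == 3  # downstream code can read the attribute directly
print(net(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])
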
2 changes: 1 addition & 1 deletion models/modules/img2img_turbo/img2img_turbo.py
@@ -65,7 +65,7 @@ def my_vae_decoder_fwd(self, sample, latent_embeds=None):


 class Img2ImgTurbo(nn.Module):
-    def __init__(self, in_channels, out_channels, lora_rank_unet, lora_rank_vae):
+    def __init__(self, in_channel, out_channel, lora_rank_unet, lora_rank_vae):
         super().__init__()

         # TODO: other params
