From dc221c6c64719af1385bc3b346e832d18c802124 Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Fri, 20 Sep 2024 15:45:42 +0000 Subject: [PATCH] fix: hdit out_channel --- models/diffusion_networks.py | 4 ++-- models/gan_networks.py | 8 ++++---- models/modules/hdit/hdit.py | 11 +++++------ models/modules/img2img_turbo/img2img_turbo.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py index 521fb4a68..355c4a4b3 100644 --- a/models/diffusion_networks.py +++ b/models/diffusion_networks.py @@ -226,8 +226,8 @@ def define_G( model = HDiT( levels=hdit_config.levels, mapping=hdit_config.mapping, - in_channels=in_channel, - out_channels=model_output_nc, + in_channel=in_channel, + out_channel=model_output_nc, patch_size=hdit_config.patch_size, num_classes=0, mapping_cond_dim=0, diff --git a/models/gan_networks.py b/models/gan_networks.py index a65da50b5..352e484eb 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -250,8 +250,8 @@ def define_G( net = HDiT( levels=hdit_config.levels, mapping=hdit_config.mapping, - in_channels=model_input_nc, - out_channels=model_output_nc, + in_channel=model_input_nc, + out_channel=model_output_nc, patch_size=hdit_config.patch_size, last_zero_init=False, num_classes=0, @@ -262,8 +262,8 @@ def define_G( return net elif G_netG == "img2img_turbo": net = Img2ImgTurbo( - in_channels=model_input_nc, - out_channels=model_output_nc, + in_channel=model_input_nc, + out_channel=model_output_nc, lora_rank_unet=G_lora_unet, lora_rank_vae=G_lora_vae, ) diff --git a/models/modules/hdit/hdit.py b/models/modules/hdit/hdit.py index dae6acff6..a1ec108de 100644 --- a/models/modules/hdit/hdit.py +++ b/models/modules/hdit/hdit.py @@ -741,8 +741,8 @@ def __init__( self, levels, mapping, - in_channels, - out_channels, + in_channel, + out_channel, patch_size, last_zero_init=True, num_classes=0, @@ -752,7 +752,8 @@ def __init__( ): super().__init__() self.num_classes = num_classes - self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size) + self.out_channel = out_channel + self.patch_in = TokenMerge(in_channel, levels[0].width, patch_size) self.mapping = tag_module( MappingNetwork( mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout @@ -820,9 +821,7 @@ def __init__( ) self.out_norm = RMSNorm(levels[0].width) - self.patch_out = TokenSplitWithoutSkip( - levels[0].width, out_channels, patch_size - ) + self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channel, patch_size) if last_zero_init: nn.init.zeros_(self.patch_out.proj.weight) diff --git a/models/modules/img2img_turbo/img2img_turbo.py b/models/modules/img2img_turbo/img2img_turbo.py index 2efb3889a..a08ade86d 100644 --- a/models/modules/img2img_turbo/img2img_turbo.py +++ b/models/modules/img2img_turbo/img2img_turbo.py @@ -65,7 +65,7 @@ def my_vae_decoder_fwd(self, sample, latent_embeds=None): class Img2ImgTurbo(nn.Module): - def __init__(self, in_channels, out_channels, lora_rank_unet, lora_rank_vae): + def __init__(self, in_channel, out_channel, lora_rank_unet, lora_rank_vae): super().__init__() # TODO: other params