From 8f1874afc24f2903dacaf93979645ecaf03c1e8b Mon Sep 17 00:00:00 2001 From: huchenlei Date: Tue, 22 Aug 2023 13:22:52 -0400 Subject: [PATCH] :bug: Fix SDXL middle block missing issue --- scripts/hook.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/hook.py b/scripts/hook.py index dde59669c..5aeb0d004 100644 --- a/scripts/hook.py +++ b/scripts/hook.py @@ -340,6 +340,8 @@ def guidance_schedule_handler(self, x): param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent def hook(self, model, sd_ldm, control_params, process): + is_sdxl = getattr(process.sd_model, 'is_sdxl', False) + self.model = model self.sd_ldm = sd_ldm self.control_params = control_params @@ -361,8 +363,6 @@ def process_sample(*args, **kwargs): return process.sample_before_CN_hack(*args, **kwargs) def forward(self, x, timesteps=None, context=None, **kwargs): - total_controlnet_embedding = [0.0] * 13 - total_t2i_adapter_embedding = [0.0] * 4 require_inpaint_hijack = False is_in_high_res_fix = False batch_size = int(x.shape[0]) @@ -429,6 +429,11 @@ def forward(self, x, timesteps=None, context=None, **kwargs): context = torch.cat([context, control.clone()], dim=1) # handle ControlNet / T2I_Adapter + input_block_num = 9 if is_sdxl else 12 + assert input_block_num % 3 == 0 + total_controlnet_embedding = [0.0] * (input_block_num + 1) + total_t2i_adapter_embedding = [0.0] * (input_block_num // 3 + 1) + for param in outer.control_params: if no_high_res_control: continue @@ -611,8 +616,11 @@ def forward(self, x, timesteps=None, context=None, **kwargs): h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) # U-Net Decoder - for module, (hs_item, controlnet_embedding_item) in zip(self.output_blocks, reversed(list(zip(hs, total_controlnet_embedding)))): - h = th.cat([h, aligned_adding(hs_item, controlnet_embedding_item, require_inpaint_hijack)], dim=1) + assert len(hs) == len(total_controlnet_embedding), \ + f"Misaligned `hs` and `controlnet_embedding`\n{len(hs)} != {len(total_controlnet_embedding)}" + + for i, module in enumerate(self.output_blocks): + h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1) h = module(h, emb, context) # U-Net Output @@ -774,7 +782,7 @@ def hacked_group_norm_forward(self, *args, **kwargs): gn_modules = [model.middle_block] model.middle_block.gn_weight = 0 - input_block_indices = [4, 5, 7, 8, 10, 11] if not getattr(process.sd_model, 'is_sdxl', False) else [4, 5, 7, 8] + input_block_indices = [4, 5, 7, 8, 10, 11] if not is_sdxl else [4, 5, 7, 8] for w, i in enumerate(input_block_indices): module = model.input_blocks[i] module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))