Skip to content

Commit

Permalink
🐛 Fix SDXL middle block missing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Aug 22, 2023
1 parent fc3ce12 commit 8f1874a
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions scripts/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def guidance_schedule_handler(self, x):
param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent

def hook(self, model, sd_ldm, control_params, process):
is_sdxl = getattr(process.sd_model, 'is_sdxl', False)

self.model = model
self.sd_ldm = sd_ldm
self.control_params = control_params
Expand All @@ -361,8 +363,6 @@ def process_sample(*args, **kwargs):
return process.sample_before_CN_hack(*args, **kwargs)

def forward(self, x, timesteps=None, context=None, **kwargs):
total_controlnet_embedding = [0.0] * 13
total_t2i_adapter_embedding = [0.0] * 4
require_inpaint_hijack = False
is_in_high_res_fix = False
batch_size = int(x.shape[0])
Expand Down Expand Up @@ -429,6 +429,11 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
context = torch.cat([context, control.clone()], dim=1)

# handle ControlNet / T2I_Adapter
input_block_num = 9 if is_sdxl else 12
assert input_block_num % 3 == 0
total_controlnet_embedding = [0.0] * (input_block_num + 1)
total_t2i_adapter_embedding = [0.0] * (input_block_num // 3 + 1)

for param in outer.control_params:
if no_high_res_control:
continue
Expand Down Expand Up @@ -611,8 +616,11 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)

# U-Net Decoder
for module, (hs_item, controlnet_embedding_item) in zip(self.output_blocks, reversed(list(zip(hs, total_controlnet_embedding)))):
h = th.cat([h, aligned_adding(hs_item, controlnet_embedding_item, require_inpaint_hijack)], dim=1)
assert len(hs) == len(total_controlnet_embedding), \
f"Misaligned `hs` and `controlnet_embedding`\n{len(hs)} != {len(total_controlnet_embedding)}"

for i, module in enumerate(self.output_blocks):
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
h = module(h, emb, context)

# U-Net Output
Expand Down Expand Up @@ -774,7 +782,7 @@ def hacked_group_norm_forward(self, *args, **kwargs):
gn_modules = [model.middle_block]
model.middle_block.gn_weight = 0

input_block_indices = [4, 5, 7, 8, 10, 11] if not getattr(process.sd_model, 'is_sdxl', False) else [4, 5, 7, 8]
input_block_indices = [4, 5, 7, 8, 10, 11] if not is_sdxl else [4, 5, 7, 8]
for w, i in enumerate(input_block_indices):
module = model.input_blocks[i]
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
Expand Down

0 comments on commit 8f1874a

Please sign in to comment.