Add advanced weighting support (#754)
huchenlei authored May 23, 2024
1 parent 49c3a08 commit eb1e12b
Showing 7 changed files with 88 additions and 6 deletions.
@@ -227,6 +227,7 @@ def __init__(
self.hr_option = None
self.batch_image_dir_state = None
self.output_dir_state = None
self.advanced_weighting = gr.State(None)

# Internal states for UI state pasting.
self.prevent_next_n_module_update = 0
@@ -607,6 +608,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.guidance_end,
self.pixel_perfect,
self.control_mode,
self.advanced_weighting,
)

unit = gr.State(self.default_unit)
@@ -221,6 +221,21 @@ class ControlNetUnit:
pixel_perfect: bool = False
# Control mode for the unit; defaults to balanced.
control_mode: ControlMode = ControlMode.BALANCED
# Weight for each layer of ControlNet params.
# For ControlNet:
# - SD1.5: 13 weights (4 encoder blocks * 3 + 1 middle block)
# - SDXL: 10 weights (3 encoder blocks * 3 + 1 middle block)
# For T2IAdapter:
# - SD1.5: 5 weights (4 encoder blocks + 1 middle block)
# - SDXL: 4 weights (3 encoder blocks + 1 middle block)
# For IPAdapter:
# - SD1.5: 16 weights (6 input blocks + 9 output blocks + 1 middle block)
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
# Note 1: Setting advanced_weighting will disable `soft_injection`, so it is
# recommended to set ControlMode = BALANCED when using `advanced_weighting`.
# Note 2: The field `weight` is still used in some places (e.g. reference_only)
# even when advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None

# Following fields should only be used in the API.
# ====== Start of API only fields ======
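The layer counts documented above determine how long an `advanced_weighting` list must be for each model type. As a minimal sketch (assumed values, not part of this commit): an SD1.5 ControlNet unit sent through the API would carry 13 weights, for example:

# Hypothetical sketch: 13 per-layer weights for an SD1.5 ControlNet unit
# (4 encoder blocks * 3 weights each + 1 middle block).
weights = [1.0] * 13
weights[12] = 0.5  # e.g. damp the middle block

unit_overrides = {
    "module": "canny",                    # assumed preprocessor
    "model": "control_v11p_sd15_canny",   # assumed model name
    "control_mode": "Balanced",           # recommended with advanced_weighting (Note 1)
    "advanced_weighting": weights,
}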
6 changes: 6 additions & 0 deletions extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -59,6 +59,7 @@ def __init__(self):
self.control_cond_for_hr_fix = None
self.control_mask = None
self.control_mask_for_hr_fix = None
self.advanced_weighting = None


class ControlNetForForgeOfficial(scripts.Script):
@@ -505,6 +506,11 @@ def process_unit_before_every_sampling(self,
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()

if unit.advanced_weighting is not None:
if params.model.positive_advanced_weighting is not None:
logger.warn("advanced_weighting overwrite control_mode")
params.model.positive_advanced_weighting = unit.advanced_weighting

cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)

params.model.advanced_mask_weighting = mask
@@ -0,0 +1,43 @@
from .template import (
APITestTemplate,
realistic_girl_face_img,
disable_in_cq,
get_model,
)


@disable_in_cq
def test_ipadapter_advanced_weighting():
weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL
# SD15 composition
weights[4] = 0.25
weights[5] = 1.0

APITestTemplate(
"test_ipadapter_advanced_weighting",
"txt2img",
payload_overrides={
"width": 512,
"height": 512,
},
unit_overrides={
"image": realistic_girl_face_img,
"module": "CLIP-ViT-H (IPAdapter)",
"model": get_model("ip-adapter_sd15"),
"advanced_weighting": weights,
},
).exec()

APITestTemplate(
"test_ipadapter_advanced_weighting_ref",
"txt2img",
payload_overrides={
"width": 512,
"height": 512,
},
unit_overrides={
"image": realistic_girl_face_img,
"module": "CLIP-ViT-H (IPAdapter)",
"model": get_model("ip-adapter_sd15"),
},
).exec()
@@ -248,13 +248,13 @@ def get_model(model_name: str) -> str:


default_unit = {
"control_mode": 0,
"control_mode": "Balanced",
"enabled": True,
"guidance_end": 1,
"guidance_start": 0,
"pixel_perfect": True,
"processor_res": 512,
"resize_mode": 1,
"resize_mode": "Crop and Resize",
"threshold_a": 64,
"threshold_b": 64,
"weight": 1,
@@ -403,7 +403,7 @@ def __call__(self, n, context_attn2, value_attn2, extra_options):
batch_prompt = b // len(cond_or_uncond)
out = optimized_attention(q, k, v, extra_options["n_heads"])
_, _, lh, lw = extra_options["original_shape"]

for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
if sigma > sigma_start or sigma < sigma_end:
continue
@@ -466,8 +466,18 @@ def __call__(self, n, context_attn2, value_attn2, extra_options):
ip_v = ip_v_offset + ip_v_mean * W

out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
if weight_type.startswith("original"):
out_ip = out_ip * weight

if weight_type == "original":
assert isinstance(weight, (float, int))
weight = weight
elif weight_type == "advanced":
assert isinstance(weight, list)
transformer_index: int = extra_options["transformer_index"]
assert transformer_index < len(weight)
weight = weight[transformer_index]
else:
weight = 1.0
out_ip = out_ip * weight

if mask is not None:
# TODO: needs checking
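For reference, a minimal sketch (assumed values) of what the new "advanced" branch above does: `extra_options["transformer_index"]` identifies the current attention layer, and the corresponding entry of the weight list scales `out_ip` for that layer only.

# Minimal sketch (assumed values): per-layer weight lookup in the "advanced" branch.
weight = [0.0] * 16          # e.g. the SD1.5 IPAdapter list from the test above
weight[4] = 0.25
weight[5] = 1.0
transformer_index = 5        # provided via extra_options during sampling
layer_weight = weight[transformer_index]  # -> 1.0; out_ip is multiplied by this value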
@@ -143,11 +143,17 @@ def __init__(self, state_dict):

def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unet = process.sd_model.forge_objects.unet
if self.positive_advanced_weighting is None:
weight = self.strength
cond["weight_type"] = "original"
else:
weight = self.positive_advanced_weighting
cond["weight_type"] = "advanced"

unet = opIPAdapterApply(
ipadapter=self.ip_adapter,
model=unet,
weight=self.strength,
weight=weight,
start_at=self.start_percent,
end_at=self.end_percent,
faceid_v2=self.faceid_v2,
