Skip to content

Commit

Permalink
New infotext format
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Aug 23, 2023
1 parent 2e02e78 commit 80ac188
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 43 deletions.
57 changes: 36 additions & 21 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
label="Enable",
value=self.default_unit.enabled,
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
elem_classes=['cnet-unit-enabled'],
elem_classes=["cnet-unit-enabled"],
)
self.low_vram = gr.Checkbox(
label="Low VRAM",
Expand Down Expand Up @@ -445,9 +445,7 @@ def webcam_mirror_toggle():
self.webcam_mirrored = not self.webcam_mirrored
return {"mirror_webcam": self.webcam_mirrored, "__type__": "update"}

self.webcam_mirror.click(
webcam_mirror_toggle, inputs=None, outputs=self.image
)
self.webcam_mirror.click(webcam_mirror_toggle, inputs=None, outputs=self.image)

def register_refresh_all_models(self):
def refresh_all_models(*inputs):
Expand All @@ -466,21 +464,34 @@ def register_build_sliders(self):
return

def build_sliders(module, pp):
default_res_slider_config = dict(
label=flag_preprocessor_resolution,
value=512,
minimum=64,
maximum=2048,
step=1,
)
# Clear old slider values so that they do not cause confusion in
# infotext.
clear_slider_update = gr.update(
visible=False,
interactive=False,
minimum=-1,
maximum=-1,
value=-1,
)

grs = []
module = global_state.get_module_basename(module)
if module not in preprocessor_sliders_config:
grs += [
gr.update(
label=flag_preprocessor_resolution,
value=512,
minimum=64,
maximum=2048,
step=1,
**default_res_slider_config,
visible=not pp,
interactive=not pp,
),
gr.update(visible=False, interactive=False),
gr.update(visible=False, interactive=False),
clear_slider_update,
clear_slider_update,
gr.update(visible=True),
]
else:
Expand All @@ -503,9 +514,9 @@ def build_sliders(module, pp):
)
)
else:
grs.append(gr.update(visible=False, interactive=False))
grs.append(clear_slider_update)
while len(grs) < 3:
grs.append(gr.update(visible=False, interactive=False))
grs.append(clear_slider_update)
grs.append(gr.update(visible=True))
if module in model_free_preprocessors:
grs += [
Expand Down Expand Up @@ -535,13 +546,17 @@ def filter_selected(k, pp):
filtered_preprocessor_list,
filtered_model_list,
default_option,
default_model
) = global_state.select_control_type(k)
default_model,
) = global_state.select_control_type(k)
return [
gr.Dropdown.update(value=default_option, choices=filtered_preprocessor_list),
gr.Dropdown.update(value=default_model, choices=filtered_model_list),
gr.Dropdown.update(
value=default_option, choices=filtered_preprocessor_list
),
gr.Dropdown.update(
value=default_model, choices=filtered_model_list
),
] + build_sliders(default_option, pp)

self.type_filter.change(
filter_selected,
inputs=[self.type_filter, self.pixel_perfect],
Expand All @@ -552,9 +567,9 @@ def register_run_annotator(self, is_img2img: bool):
def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm):
if image is None:
return (
gr.update(value=None, visible=True),
gr.update(),
*self.openpose_editor.update(''),
gr.update(value=None, visible=True),
gr.update(),
*self.openpose_editor.update(""),
)

img = HWC3(image["image"])
Expand Down
89 changes: 67 additions & 22 deletions scripts/infotext.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from typing import List, Tuple
from typing import List, Tuple, Union

import gradio as gr

Expand All @@ -10,6 +9,45 @@
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup


def field_to_displaytext(fieldname: str) -> str:
return " ".join([word.capitalize() for word in fieldname.split("_")])


def displaytext_to_field(text: str) -> str:
return "_".join([word.lower() for word in text.split(" ")])


def parse_value(value: str) -> Union[str, float, int, bool]:
if value in ('True', 'False'):
return value == 'True'
try:
return int(value)
except ValueError:
try:
return float(value)
except ValueError:
return value # Plain string.


def serialize_unit(unit: external_code.ControlNetUnit) -> str:
log_value = {
field_to_displaytext(field): getattr(unit, field)
for field in vars(external_code.ControlNetUnit()).keys()
if field not in ("image", "enabled") and getattr(unit, field) != -1
# Note: exclude hidden slider values.
}
assert all("," not in str(v) and ":" not in str(v) for v in log_value.values())
return ", ".join(f"{field}: {value}" for field, value in log_value.items())


def parse_unit(text: str) -> external_code.ControlNetUnit:
return external_code.ControlNetUnit(enabled=True, **{
displaytext_to_field(key): parse_value(value)
for item in text.split(',')
for (key, value) in (item.strip().split(': '),)
})


class Infotext(object):
def __init__(self) -> None:
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
Expand All @@ -19,15 +57,21 @@ def __init__(self) -> None:
def unit_prefix(unit_index: int) -> str:
return f"ControlNet {unit_index}"

@staticmethod
def field_to_displaytext(fieldname: str) -> str:
return " ".join([word.capitalize() for word in fieldname.split("_")])

def register_unit(self, unit_index: int, uigroup: ControlNetUiGroup):
def register_unit(self, unit_index: int, uigroup: ControlNetUiGroup) -> None:
"""Register the unit's UI group. By regsitering the unit, A1111 will be
able to paste values from infotext to IOComponents."""
able to paste values from infotext to IOComponents.
Args:
unit_index: The index of the ControlNet unit
uigroup: The ControlNetUiGroup instance that contains all gradio
iocomponents.
"""
unit_prefix = Infotext.unit_prefix(unit_index)
for field in vars(external_code.ControlNetUnit()).keys():
# Exclude image for infotext.
if field == "image":
continue

# Every field in ControlNetUnit should have a cooresponding
# IOComponent in ControlNetUiGroup.
io_component = getattr(uigroup, field)
Expand All @@ -37,35 +81,36 @@ def register_unit(self, unit_index: int, uigroup: ControlNetUiGroup):

@staticmethod
def write_infotext(
units: external_code.ControlNetUnit, p: StableDiffusionProcessing
):
units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
):
"""Write infotext to `p`."""
p.extra_generation_params.update(
{
Infotext.unit_prefix(i): json.dumps({
k: v
for k, v in vars(unit).items()
# Write everything except `image` to infotext.
if k in vars(external_code.ControlNetUnit()) and k != 'image'
})
Infotext.unit_prefix(i): serialize_unit(unit)
for i, unit in enumerate(units)
if unit.enabled
}
)

@staticmethod
def on_infotext_pasted(infotext: str, results: dict) -> None:
""" """
"""Parse ControlNet infotext string and write result to `results` dict."""
updates = {}
for k, v in results.items():
if not k.startswith("ControlNet"):
continue

assert isinstance(v, str), f"Expect string but got {v}."
try:
for field, value in json.loads(v).items():
for field, value in vars(parse_unit(v)).items():
if field == "image":
continue

assert value is not None, f"{field} == None"
component_locator = f"{k} {field}"
updates[component_locator] = value
print(f'Setting {component_locator} = {value}')
except json.JSONDecodeError as e:
logger.warn(f"Failed to parse infotext:\n{v}\n{e}")
logger.info(f"InfoText: Setting {component_locator} = {value}")
except Exception as e:
logger.warn(f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}")

results.update(updates)

0 comments on commit 80ac188

Please sign in to comment.