Skip to content

Commit

Permalink
Update control type in preset (#1980)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored Aug 25, 2023
1 parent cc2cfc0 commit 18bbe35
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
1 change: 1 addition & 0 deletions javascript/active_units.js
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
setTimeout(() => {
this.updateActiveState();
this.updateActiveUnitCount();
this.updateActiveControlType();
}, 100);
return;
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def register_callbacks(self, is_img2img: bool):
self.openpose_editor.register_callbacks(
self.generated_image, self.use_preview_as_input
)
self.preset_panel.register_callbacks(*[
self.preset_panel.register_callbacks(self.type_filter, *[
getattr(self, key) for key in vars(external_code.ControlNetUnit()).keys()
])
if is_img2img:
Expand Down
44 changes: 37 additions & 7 deletions scripts/controlnet_ui/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scripts.infotext import parse_unit, serialize_unit
from scripts.controlnet_ui.tool_button import ToolButton
from scripts.logging import logger
from scripts.processor import preprocessor_filters
from scripts import external_code

save_symbol = "\U0001f4be" # 💾
Expand All @@ -32,6 +33,20 @@ def load_presets(preset_dir: str) -> Dict[str, str]:
return presets


def infer_control_type(module: str, model: str) -> str:
control_types = preprocessor_filters.keys()
control_type_candidates = [
control_type
for control_type in control_types
if control_type.lower() in module or control_type.lower() in model
]
if len(control_type_candidates) != 1:
raise ValueError(
f"Unable to infer control type from module {module} and model {model}"
)
return control_type_candidates[0]


class ControlNetPresetUI(object):
preset_directory = os.path.join(scripts.basedir(), "presets")
presets = load_presets(preset_directory)
Expand Down Expand Up @@ -88,11 +103,11 @@ def render(self, id_prefix: str):
tooltip="Save preset",
)

def register_callbacks(self, *ui_states):
def register_callbacks(self, control_type: gr.Radio, *ui_states):
self.dropdown.change(
fn=ControlNetPresetUI.apply_preset,
inputs=[self.dropdown],
outputs=[self.delete_button, *ui_states],
outputs=[self.delete_button, control_type, *ui_states],
show_progress=False,
)

Expand Down Expand Up @@ -194,17 +209,32 @@ def apply_preset(name: str):
if name == NEW_PRESET:
return (
gr.update(visible=False),
*((gr.update(),) * len(vars(external_code.ControlNetUnit()).keys())),
*(
(gr.update(),)
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
),
)

assert name in ControlNetPresetUI.presets
infotext = ControlNetPresetUI.presets[name]
unit = parse_unit(infotext)

return gr.update(visible=True), *[
gr.update(value=value) if value is not None else gr.update()
for value in vars(unit).values()
]
try:
control_type_update = gr.update(
value=infer_control_type(unit.module, unit.model)
)
except ValueError as e:
logger.error(e)
control_type_update = gr.update()

return (
gr.update(visible=True),
control_type_update,
*[
gr.update(value=value) if value is not None else gr.update()
for value in vars(unit).values()
],
)

@staticmethod
def refresh_preset():
Expand Down

0 comments on commit 18bbe35

Please sign in to comment.