Skip to content

Commit

Permalink
[New Feature] ControlNet unit preset management (#1974)
Browse files Browse the repository at this point in the history
* WIP preset

* basic func done

* Add refresh button

* fix active units

* fix hint issue
  • Loading branch information
huchenlei authored Aug 25, 2023
1 parent b875b84 commit 6dc7b4f
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 16 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,7 @@ annotator/downloads/
# test results and expectations
web_tests/results/
web_tests/expectations/
*_diff.png
*_diff.png

# Presets
presets/
20 changes: 20 additions & 0 deletions javascript/active_units.js
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
this.attachImageUploadListener();
this.attachImageStateChangeObserver();
this.attachA1111SendInfoObserver();
this.attachPresetDropdownObserver();

// Initial updates:
if (this.isImg2Img)
Expand Down Expand Up @@ -300,6 +301,25 @@
});
}
}

attachPresetDropdownObserver() {
const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown');

new MutationObserver((mutationsList) => {
for (const mutation of mutationsList) {
if (mutation.removedNodes.length > 0) {
setTimeout(() => {
this.updateActiveState();
this.updateActiveUnitCount();
}, 100);
return;
}
}
}).observe(presetDropDown, {
childList: true,
subtree: true,
});
}
}

gradioApp().querySelectorAll('#controlnet').forEach(accordion => {
Expand Down
2 changes: 2 additions & 0 deletions javascript/hints.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
'📝': 'Open new canvas',
'📷': 'Enable webcam',
'⇄': 'Mirror webcam',
'💾': 'Save preset',
'🗑️': 'Delete preset',
};

onUiUpdate(function () {
Expand Down
21 changes: 8 additions & 13 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,11 @@
)
from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.tool_button import ToolButton
from modules import shared
from modules.ui_components import FormRow


class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""

def __init__(self, **kwargs):
super().__init__(variant="tool",
elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"],
**kwargs)

def get_block_name(self):
return "button"


class UiControlNetUnit(external_code.ControlNetUnit):
"""The data class that stores all states of a ControlNetUnit."""

Expand Down Expand Up @@ -147,6 +136,7 @@ def __init__(
self.loopback = None
self.use_preview_as_input = None
self.openpose_editor = None
self.preset_panel = None
self.upload_independent_img_in_img2img = None
self.image_upload_panel = None

Expand Down Expand Up @@ -413,6 +403,8 @@ def render(self, tabname: str, elem_id_tabname: str, is_img2img: bool) -> None:
visible=not is_img2img
)

self.preset_panel = ControlNetPresetUI(id_prefix=f"{elem_id_tabname}_{tabname}_")

def register_send_dimensions(self, is_img2img: bool):
"""Register event handler for send dimension button."""

Expand Down Expand Up @@ -776,6 +768,9 @@ def register_callbacks(self, is_img2img: bool):
self.openpose_editor.register_callbacks(
self.generated_image, self.use_preview_as_input
)
self.preset_panel.register_callbacks(*[
getattr(self, key) for key in vars(external_code.ControlNetUnit()).keys()
])
if is_img2img:
self.register_img2img_same_input()

Expand Down
212 changes: 212 additions & 0 deletions scripts/controlnet_ui/preset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import os
import gradio as gr

from typing import Dict, List

from modules import scripts
from scripts.infotext import parse_unit, serialize_unit
from scripts.controlnet_ui.tool_button import ToolButton
from scripts.logging import logger
from scripts import external_code

save_symbol = "\U0001f4be" # 💾
delete_symbol = "\U0001f5d1\ufe0f" # 🗑️
refresh_symbol = "\U0001f504" # 🔄

NEW_PRESET = "New Preset"


def load_presets(preset_dir: str) -> Dict[str, str]:
if not os.path.exists(preset_dir):
os.makedirs(preset_dir)
return {}

presets = {}
for filename in os.listdir(preset_dir):
if filename.endswith(".txt"):
with open(os.path.join(preset_dir, filename), "r") as f:
name = filename.replace(".txt", "")
if name == NEW_PRESET:
continue
presets[name] = f.read()
return presets


class ControlNetPresetUI(object):
preset_directory = os.path.join(scripts.basedir(), "presets")
presets = load_presets(preset_directory)

def __init__(self, id_prefix: str):
self.dropdown = None
self.save_button = None
self.delete_button = None
self.refresh_button = None
self.preset_name = None
self.confirm_preset_name = None
self.name_dialog = None
self.render(id_prefix)

def render(self, id_prefix: str):
with gr.Row():
self.dropdown = gr.Dropdown(
label="Presets",
show_label=True,
elem_classes=["cnet-preset-dropdown"],
choices=ControlNetPresetUI.dropdown_choices(),
value=NEW_PRESET,
)
self.save_button = ToolButton(
value=save_symbol,
elem_classes=["cnet-preset-save"],
tooltip="Save preset",
)
self.delete_button = ToolButton(
value=delete_symbol,
elem_classes=["cnet-preset-delete"],
tooltip="Delete preset",
)
self.refresh_button = ToolButton(
value=refresh_symbol,
elem_classes=["cnet-preset-refresh"],
tooltip="Refresh preset",
)

with gr.Box(
elem_classes=["popup-dialog", "cnet-preset-enter-name"],
elem_id=f"{id_prefix}_cnet_preset_enter_name",
) as self.name_dialog:
with gr.Row():
self.preset_name = gr.Textbox(
label="Preset name",
show_label=True,
lines=1,
elem_classes=["cnet-preset-name"],
)
self.confirm_preset_name = ToolButton(
value=save_symbol,
elem_classes=["cnet-preset-confirm-name"],
tooltip="Save preset",
)

def register_callbacks(self, *ui_states):
self.dropdown.change(
fn=ControlNetPresetUI.apply_preset,
inputs=[self.dropdown],
outputs=[self.delete_button, *ui_states],
show_progress=False,
)

def save_preset(name: str, *ui_states):
if name == NEW_PRESET:
return gr.update(visible=True), gr.update()

ControlNetPresetUI.save_preset(
name, external_code.ControlNetUnit(*ui_states)
)
return gr.update(), gr.update(
choices=ControlNetPresetUI.dropdown_choices(), value=name
)

self.save_button.click(
fn=save_preset,
inputs=[self.dropdown, *ui_states],
outputs=[self.name_dialog, self.dropdown],
show_progress=False,
).then(
fn=None,
_js=f"""
(name) => {{
if (name === "{NEW_PRESET}")
popup(gradioApp().getElementById('{self.name_dialog.elem_id}'));
}}""",
inputs=[self.dropdown],
)

def delete_preset(name: str):
ControlNetPresetUI.delete_preset(name)
return gr.Dropdown.update(
choices=ControlNetPresetUI.dropdown_choices(),
value=NEW_PRESET,
)

self.delete_button.click(
fn=delete_preset,
inputs=[self.dropdown],
outputs=[self.dropdown],
show_progress=False,
)

self.name_dialog.visible = False

def save_new_preset(new_name: str, *ui_states):
if new_name == NEW_PRESET:
logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'")
return gr.update(visible=False), gr.update()

ControlNetPresetUI.save_preset(
new_name, external_code.ControlNetUnit(*ui_states)
)
return gr.update(visible=False), gr.update(
choices=ControlNetPresetUI.dropdown_choices(), value=new_name
)

self.confirm_preset_name.click(
fn=save_new_preset,
inputs=[self.preset_name, *ui_states],
outputs=[self.name_dialog, self.dropdown],
show_progress=False,
).then(fn=None, _js="closePopup")

self.refresh_button.click(
fn=ControlNetPresetUI.refresh_preset,
inputs=None,
outputs=[self.dropdown],
show_progress=False,
)

@staticmethod
def dropdown_choices() -> List[str]:
return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

@staticmethod
def save_preset(name: str, unit: external_code.ControlNetUnit):
infotext = serialize_unit(unit)
with open(
os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
) as f:
f.write(infotext)

ControlNetPresetUI.presets[name] = infotext

@staticmethod
def delete_preset(name: str):
if name not in ControlNetPresetUI.presets:
return

del ControlNetPresetUI.presets[name]

file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt")
if os.path.exists(file):
os.unlink(file)

@staticmethod
def apply_preset(name: str):
if name == NEW_PRESET:
return (
gr.update(visible=False),
*((gr.update(),) * len(vars(external_code.ControlNetUnit()).keys())),
)

assert name in ControlNetPresetUI.presets
infotext = ControlNetPresetUI.presets[name]
unit = parse_unit(infotext)

return gr.update(visible=True), *[
gr.update(value=value) if value is not None else gr.update()
for value in vars(unit).values()
]

@staticmethod
def refresh_preset():
ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory)
return gr.update(choices=ControlNetPresetUI.dropdown_choices())
12 changes: 12 additions & 0 deletions scripts/controlnet_ui/tool_button.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import gradio as gr

class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""

def __init__(self, **kwargs):
super().__init__(variant="tool",
elem_classes=kwargs.pop('elem_classes', []) + ["cnet-toolbutton"],
**kwargs)

def get_block_name(self):
return "button"
3 changes: 1 addition & 2 deletions scripts/infotext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from scripts import external_code
from scripts.logging import logger
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup


def field_to_displaytext(fieldname: str) -> str:
Expand Down Expand Up @@ -63,7 +62,7 @@ def __init__(self) -> None:
def unit_prefix(unit_index: int) -> str:
return f"ControlNet {unit_index}"

def register_unit(self, unit_index: int, uigroup: ControlNetUiGroup) -> None:
def register_unit(self, unit_index: int, uigroup) -> None:
"""Register the unit's UI group. By regsitering the unit, A1111 will be
able to paste values from infotext to IOComponents.
Expand Down

0 comments on commit 6dc7b4f

Please sign in to comment.