Skip to content

Commit

Permalink
Repair infotext parsing
Browse files Browse the repository at this point in the history
🚧 Listen png info buttons

WIP

New infotext format

nits

nits

nits

rebase

refactor active unit counts
  • Loading branch information
catboxanon authored and huchenlei committed Aug 24, 2023
1 parent 7140c7f commit 288e0b3
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 136 deletions.
94 changes: 59 additions & 35 deletions javascript/active_units.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
* Disable resize mode selection when A1111 img2img input is used.
*/
(function () {
const cnetAllUnits = new Map/* <Element, ControlNetUnitTab> */();
const cnetAllAccordions = new Set();
onUiUpdate(() => {
const ImgChangeType = {
Expand Down Expand Up @@ -64,8 +63,9 @@
}

class ControlNetUnitTab {
constructor(tab) {
constructor(tab, accordion) {
this.tab = tab;
this.accordion = accordion;
this.isImg2Img = tab.querySelector('.cnet-unit-enabled').id.includes('img2img');

this.enabledCheckbox = tab.querySelector('.cnet-unit-enabled input');
Expand All @@ -84,6 +84,7 @@
this.attachTabNavChangeObserver();
this.attachImageUploadListener();
this.attachImageStateChangeObserver();
this.attachA1111SendInfoObserver();

// Initial updates:
if (this.isImg2Img)
Expand Down Expand Up @@ -114,6 +115,34 @@
}
}

updateActiveUnitCount() {
function getActiveUnitCount(checkboxes) {
let activeUnitCount = 0;
for (const checkbox of checkboxes) {
if (checkbox.checked)
activeUnitCount++;
}
return activeUnitCount;
}

const checkboxes = this.accordion.querySelectorAll('.cnet-unit-enabled input');
const span = this.accordion.querySelector('.label-wrap span');

// Remove existing badge.
if (span.childNodes.length !== 1) {
span.removeChild(span.lastChild);
}
// Add new badge if necessary.
const activeUnitCount = getActiveUnitCount(checkboxes);
if (activeUnitCount > 0) {
const div = document.createElement('div');
div.classList.add('cnet-badge');
div.classList.add('primary');
div.innerHTML = `${activeUnitCount} unit${activeUnitCount > 1 ? 's' : ''}`;
span.appendChild(div);
}
}

/**
* Add the active control type to tab displayed text.
*/
Expand Down Expand Up @@ -182,6 +211,7 @@
attachEnabledButtonListener() {
this.enabledCheckbox.addEventListener('change', () => {
this.updateActiveState();
this.updateActiveUnitCount();
});
}

Expand Down Expand Up @@ -243,45 +273,39 @@
subtree: true,
});
}
}

gradioApp().querySelectorAll('.cnet-unit-tab').forEach(tab => {
if (cnetAllUnits.has(tab)) return;
cnetAllUnits.set(tab, new ControlNetUnitTab(tab));
});

function getActiveUnitCount(checkboxes) {
let activeUnitCount = 0;
for (const checkbox of checkboxes) {
if (checkbox.checked)
activeUnitCount++;
/**
* Observe send PNG info buttons in A1111, as they can also directly
* set states of ControlNetUnit.
*/
attachA1111SendInfoObserver() {
const pasteButtons = gradioApp().querySelectorAll('#paste');
const pngButtons = gradioApp().querySelectorAll(
this.isImg2Img ?
'#img2img_tab, #inpaint_tab' :
'#txt2img_tab'
);

for (const button of [...pasteButtons, ...pngButtons]) {
button.addEventListener('click', () => {
// The paste/send img generation info feature goes
// though gradio, which is pretty slow. Ideally we should
// observe the event when gradio has done the job, but
// that is not an easy task.
// Here we just do a 2 second delay until the refresh.
setTimeout(() => {
this.updateActiveState();
this.updateActiveUnitCount();
}, 2000);
});
}
}
return activeUnitCount;
}

gradioApp().querySelectorAll('#controlnet').forEach(accordion => {
if (cnetAllAccordions.has(accordion)) return;
const checkboxes = accordion.querySelectorAll('.cnet-unit-enabled input');
if (!checkboxes) return;

const span = accordion.querySelector('.label-wrap span');
checkboxes.forEach(checkbox => {
checkbox.addEventListener('change', () => {
// Remove existing badge.
if (span.childNodes.length !== 1) {
span.removeChild(span.lastChild);
}
// Add new badge if necessary.
const activeUnitCount = getActiveUnitCount(checkboxes);
if (activeUnitCount > 0) {
const div = document.createElement('div');
div.classList.add('cnet-badge');
div.classList.add('primary');
div.innerHTML = `${activeUnitCount} unit${activeUnitCount > 1 ? 's' : ''}`;
span.appendChild(div);
}
});
});
accordion.querySelectorAll('.cnet-unit-tab')
.forEach(tab => new ControlNetUnitTab(tab, accordion));
cnetAllAccordions.add(accordion);
});
});
Expand Down
64 changes: 25 additions & 39 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gc
import os
import logging
import re
from collections import OrderedDict
from copy import copy
from typing import Dict, Optional, Tuple
Expand All @@ -21,6 +22,7 @@
from scripts.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img
from modules.images import save_image
from scripts.infotext import Infotext

import cv2
import numpy as np
Expand Down Expand Up @@ -246,24 +248,23 @@ def get_default_ui_unit(is_ui=True):
model="None"
)

def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str):
def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str) -> Tuple[ControlNetUiGroup, gr.State]:
group = ControlNetUiGroup(
gradio_compat,
self.infotext_fields,
Script.get_default_ui_unit(),
self.preprocessor,
)
group.render(tabname, elem_id_tabname, is_img2img)
group.register_callbacks(is_img2img)
return group.render_and_register_unit(tabname, is_img2img)
return group, group.render_and_register_unit(tabname, is_img2img)

def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
self.infotext_fields = []
self.paste_field_names = []
infotext = Infotext()

controls = ()
max_models = shared.opts.data.get("control_net_max_models_num", 1)
elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
Expand All @@ -274,14 +275,18 @@ def ui(self, is_img2img):
for i in range(max_models):
with gr.Tab(f"ControlNet Unit {i}",
elem_classes=['cnet-unit-tab']):
controls += (self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname),)
group, state = self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname)
infotext.register_unit(i, group)
controls += (state,)
else:
with gr.Column():
controls += (self.uigroup(f"ControlNet", is_img2img, elem_id_tabname),)
group, state = self.uigroup(f"ControlNet", is_img2img, elem_id_tabname)
infotext.register_unit(0, group)
controls += (state,)

if shared.opts.data.get("control_net_sync_field_args", False):
for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name)
if shared.opts.data.get("control_net_sync_field_args", True):
self.infotext_fields = infotext.infotext_fields
self.paste_field_names = infotext.paste_field_names

return controls

Expand Down Expand Up @@ -562,39 +567,19 @@ def high_quality_resize(x, size):
@staticmethod
def get_enabled_units(p):
units = external_code.get_all_units_in_processing(p)
enabled_units = []

if len(units) == 0:
# fill a null group
remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0)
if remote_unit.enabled:
units.append(remote_unit)

for idx, unit in enumerate(units):
unit = Script.parse_remote_call(p, unit, idx)
if not unit.enabled:
continue

enabled_units.append(copy(unit))
if len(units) != 1:
log_key = f"ControlNet {idx}"
else:
log_key = "ControlNet"

log_value = {
"preprocessor": unit.module,
"model": unit.model,
"weight": unit.weight,
"starting/ending": str((unit.guidance_start, unit.guidance_end)),
"resize mode": str(unit.resize_mode),
"pixel perfect": str(unit.pixel_perfect),
"control mode": str(unit.control_mode),
"preprocessor params": str((unit.processor_res, unit.threshold_a, unit.threshold_b)),
}
log_value = str(log_value).replace('\'', '').replace('{', '').replace('}', '')

p.extra_generation_params.update({log_key: log_value})


enabled_units = [
local_unit
for idx, unit in enumerate(units)
for local_unit in (Script.parse_remote_call(p, unit, idx),)
if local_unit.enabled
]
Infotext.write_infotext(enabled_units, p)
return enabled_units

@staticmethod
Expand Down Expand Up @@ -1064,7 +1049,7 @@ def on_ui_settings():
shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo(
False, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo(
False, "Passing ControlNet parameters with \"Send to img2img\"", gr.Checkbox, {"interactive": True}, section=section))
True, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo(
False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo(
Expand All @@ -1080,4 +1065,5 @@ def on_ui_settings():

batch_hijack.instance.do_hijack()
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
Loading

0 comments on commit 288e0b3

Please sign in to comment.