diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a9ecacaa..29b31e288 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -94,10 +94,6 @@ jobs: wait-for-it --service 127.0.0.1:7860 -t 600 python -m pytest -v --junitxml=test/results.xml --cov ./extensions/sd-webui-controlnet --cov-report=xml --verify-base-url ./extensions/sd-webui-controlnet/tests working-directory: stable-diffusion-webui - - name: Run unit tests - run: | - python -m pytest -v ./unit_tests/ - working-directory: stable-diffusion-webui/extensions/sd-webui-controlnet/ - name: Kill test server if: always() run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 diff --git a/README.md b/README.md index 5a226a508..c4dda14c2 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ # ControlNet for Stable Diffusion WebUI The WebUI extension for ControlNet and other injection-based SD controls. -![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/261f9a50-ba9c-472f-b398-fced61929c4a) + +![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/51172d20-606b-4b9f-aba5-db2f2417cb0b) This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui), allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly, the merging is not required. # News - -- [2024-05-04] 🔥[v1.1.447] PuLID [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2841] -- [2024-04-30] 🔥[v1.1.446] Effective region mask supported for ControlNet/IPAdapter [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2831] - [2024-04-27] 🔥ControlNet-lllite Normal Dsine released [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2813] - [2024-04-19] 🔥[v1.1.445] IPAdapter advanced weight [Instant Style] [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2770] - [2024-04-17] 🔥[v1.1.444] Marigold depth preprocessor [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2760] diff --git a/internal_controlnet/__init__.py b/internal_controlnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal_controlnet/args.py b/internal_controlnet/args.py deleted file mode 100644 index 50e32802e..000000000 --- a/internal_controlnet/args.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations -import os -import torch -import numpy as np -from typing import Optional, List, Annotated, ClassVar, Callable, Any, Tuple, Union -from pydantic import BaseModel, validator, root_validator, Field -from PIL import Image -from logging import Logger -from copy import copy -from enum import Enum - -from scripts.enums import ( - InputMode, - ResizeMode, - ControlMode, - HiResFixOption, - PuLIDMode, -) - - -def _unimplemented_func(*args, **kwargs): - raise NotImplementedError("Not implemented.") - - -def field_to_displaytext(fieldname: str) -> str: - return " ".join([word.capitalize() for word in fieldname.split("_")]) - - -def displaytext_to_field(text: str) -> str: - return "_".join([word.lower() for word in text.split(" ")]) - - -def serialize_value(value) -> str: - if isinstance(value, Enum): - return value.value - return str(value) - - -def parse_value(value: str) -> Union[str, float, int, bool]: - if value in ("True", "False"): - return value == "True" - try: - return int(value) - except ValueError: - try: - return float(value) - except ValueError: - return value # Plain string. - - -class ControlNetUnit(BaseModel): - """ - Represents an entire ControlNet processing unit. - """ - - class Config: - arbitrary_types_allowed = True - extra = "ignore" - - cls_match_module: ClassVar[Callable[[str], bool]] = _unimplemented_func - cls_match_model: ClassVar[Callable[[str], bool]] = _unimplemented_func - cls_decode_base64: ClassVar[Callable[[str], np.ndarray]] = _unimplemented_func - cls_torch_load_base64: ClassVar[Callable[[Any], torch.Tensor]] = _unimplemented_func - cls_get_preprocessor: ClassVar[Callable[[str], Any]] = _unimplemented_func - cls_logger: ClassVar[Logger] = Logger("ControlNetUnit") - - # UI only fields. - is_ui: bool = False - input_mode: InputMode = InputMode.SIMPLE - batch_images: Optional[Any] = None - output_dir: str = "" - loopback: bool = False - - # General fields. - enabled: bool = False - module: str = "none" - - @validator("module", always=True, pre=True) - def check_module(cls, value: str) -> str: - if not ControlNetUnit.cls_match_module(value): - raise ValueError(f"module({value}) not found in supported modules.") - return value - - model: str = "None" - - @validator("model", always=True, pre=True) - def check_model(cls, value: str) -> str: - if not ControlNetUnit.cls_match_model(value): - raise ValueError(f"model({value}) not found in supported models.") - return value - - weight: Annotated[float, Field(ge=0.0, le=2.0)] = 1.0 - - # The image to be used for this ControlNetUnit. - image: Optional[Any] = None - - resize_mode: ResizeMode = ResizeMode.INNER_FIT - low_vram: bool = False - processor_res: int = -1 - threshold_a: float = -1 - threshold_b: float = -1 - - @root_validator - def bound_check_params(cls, values: dict) -> dict: - """ - Checks and corrects negative parameters in ControlNetUnit 'unit' in place. - Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to - their default values if negative. - """ - enabled = values.get("enabled") - if not enabled: - return values - - module = values.get("module") - if not module: - return values - - preprocessor = cls.cls_get_preprocessor(module) - assert preprocessor is not None - for unit_param, param in zip( - ("processor_res", "threshold_a", "threshold_b"), - ("slider_resolution", "slider_1", "slider_2"), - ): - value = values.get(unit_param) - cfg = getattr(preprocessor, param) - if value < cfg.minimum or value > cfg.maximum: - values[unit_param] = cfg.value - # Only report warning when non-default value is used. - if value != -1: - cls.cls_logger.info( - f"[{module}.{unit_param}] Invalid value({value}), using default value {cfg.value}." - ) - return values - - guidance_start: Annotated[float, Field(ge=0.0, le=1.0)] = 0.0 - guidance_end: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0 - - @root_validator - def guidance_check(cls, values: dict) -> dict: - start = values.get("guidance_start") - end = values.get("guidance_end") - if start > end: - raise ValueError(f"guidance_start({start}) > guidance_end({end})") - return values - - pixel_perfect: bool = False - control_mode: ControlMode = ControlMode.BALANCED - # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area` - # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`. - inpaint_crop_input_image: bool = True - # If hires fix is enabled in A1111, how should this ControlNet unit be applied. - # The value is ignored if the generation is not using hires fix. - hr_option: HiResFixOption = HiResFixOption.BOTH - - # Whether save the detected map of this unit. Setting this option to False prevents saving the - # detected map or sending detected map along with generated images via API. - # Currently the option is only accessible in API calls. - save_detected_map: bool = True - - # Weight for each layer of ControlNet params. - # For ControlNet: - # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block) - # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block) - # For T2IAdapter - # - SD1.5: 5 weights (4 encoder block + 1 middle block) - # - SDXL: 4 weights (3 encoder block + 1 middle block) - # For IPAdapter - # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block) - # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block) - # Note1: Setting advanced weighting will disable `soft_injection`, i.e. - # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`. - # Note2: The field `weight` is still used in some places, e.g. reference_only, - # even advanced_weighting is set. - advanced_weighting: Optional[List[float]] = None - - # The effective region mask that unit's effect should be restricted to. - effective_region_mask: Optional[np.ndarray] = None - - @validator("effective_region_mask", pre=True) - def parse_effective_region_mask(cls, value) -> np.ndarray: - if isinstance(value, str): - return cls.cls_decode_base64(value) - assert isinstance(value, np.ndarray) or value is None - return value - - # The weight mode for PuLID. - # https://github.com/ToTheBeginning/PuLID - pulid_mode: PuLIDMode = PuLIDMode.FIDELITY - - # ------- API only fields ------- - # The tensor input for ipadapter. When this field is set in the API, - # the base64string will be interpret by torch.load to reconstruct ipadapter - # preprocessor output. - # Currently the option is only accessible in API calls. - ipadapter_input: Optional[List[torch.Tensor]] = None - - @validator("ipadapter_input", pre=True) - def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]: - if value is None: - return None - if isinstance(value, str): - value = [value] - result = [cls.cls_torch_load_base64(b) for b in value] - assert result, "input cannot be empty" - return result - - # The mask to be used on top of the image. - mask: Optional[Any] = None - - @property - def accepts_multiple_inputs(self) -> bool: - """This unit can accept multiple input images.""" - return self.module in ( - "ip-adapter-auto", - "ip-adapter_clip_sdxl", - "ip-adapter_clip_sdxl_plus_vith", - "ip-adapter_clip_sd15", - "ip-adapter_face_id", - "ip-adapter_face_id_plus", - "ip-adapter_pulid", - "instant_id_face_embedding", - ) - - @property - def is_animate_diff_batch(self) -> bool: - return getattr(self, "animatediff_batch", False) - - @property - def uses_clip(self) -> bool: - """Whether this unit uses clip preprocessor.""" - return any( - ( - ("ip-adapter" in self.module and "face_id" not in self.module), - self.module - in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"), - ) - ) - - @property - def is_inpaint(self) -> bool: - return "inpaint" in self.module - - def get_actual_preprocessor(self): - if self.module == "ip-adapter-auto": - return ControlNetUnit.cls_get_preprocessor( - self.module - ).get_preprocessor_by_model(self.model) - return ControlNetUnit.cls_get_preprocessor(self.module) - - @classmethod - def parse_image(cls, image) -> np.ndarray: - if isinstance(image, np.ndarray): - np_image = image - elif isinstance(image, str): - # Necessary for batch. - if os.path.exists(image): - np_image = np.array(Image.open(image)).astype("uint8") - else: - np_image = cls.cls_decode_base64(image) - else: - raise ValueError(f"Unrecognized image format {image}.") - - # [H, W] => [H, W, 3] - if np_image.ndim == 2: - np_image = np.stack([np_image, np_image, np_image], axis=-1) - assert np_image.ndim == 3 - assert np_image.shape[2] == 3 - return np_image - - @classmethod - def combine_image_and_mask( - cls, np_image: np.ndarray, np_mask: Optional[np.ndarray] = None - ) -> np.ndarray: - """RGB + Alpha(Optional) => RGBA""" - # TODO: Change protocol to use 255 as A channel value. - # Note: mask is by default zeros, as both inpaint and - # clip mask does extra work on masked area. - np_mask = (np.zeros_like(np_image) if np_mask is None else np_mask)[:, :, 0:1] - if np_image.shape[:2] != np_mask.shape[:2]: - raise ValueError( - f"image shape ({np_image.shape[:2]}) not aligned with mask shape ({np_mask.shape[:2]})" - ) - return np.concatenate([np_image, np_mask], axis=2) # [H, W, 4] - - @classmethod - def legacy_field_alias(cls, values: dict) -> dict: - ext_compat_keys = { - "guidance": "guidance_end", - "lowvram": "low_vram", - "input_image": "image", - } - for alias, key in ext_compat_keys.items(): - if alias in values: - assert key not in values, f"Conflict of field '{alias}' and '{key}'" - values[key] = alias - cls.cls_logger.warn( - f"Deprecated alias '{alias}' detected. This field will be removed on 2024-06-01" - f"Please use '{key}' instead." - ) - - return values - - @classmethod - def mask_alias(cls, values: dict) -> dict: - """ - Field "mask_image" is the alias of field "mask". - This is for compatibility with SD Forge API. - """ - mask_image = values.get("mask_image") - mask = values.get("mask") - if mask_image is not None: - if mask is not None: - raise ValueError("Cannot specify both 'mask' and 'mask_image'!") - values["mask"] = mask_image - return values - - def get_input_images_rgba(self) -> Optional[List[np.ndarray]]: - """ - RGBA images with potentially different size. - Why we cannot have [B, H, W, C=4] here is that calculation of final - resolution requires generation target's dimensions. - - Parse image with following formats. - API - - image = {"image": base64image, "mask": base64image,} - - image = [image, mask] - - image = (image, mask) - - image = [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...] - - image = base64image, mask = base64image - - UI: - - image = {"image": np_image, "mask": np_image,} - - image = np_image, mask = np_image - """ - init_image = self.image - init_mask = self.mask - - if init_image is None: - assert init_mask is None - return None - - if isinstance(init_image, (list, tuple)): - if not init_image: - raise ValueError(f"{init_image} is not a valid 'image' field value") - if isinstance(init_image[0], dict): - # [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...] - images = init_image - else: - assert len(init_image) == 2 - # [image, mask] - # (image, mask) - images = [ - { - "image": init_image[0], - "mask": init_image[1], - } - ] - elif isinstance(init_image, dict): - # {"image": ..., "mask": ...} - images = [init_image] - elif isinstance(init_image, (str, np.ndarray)): - # image = base64image, mask = base64image - images = [ - { - "image": init_image, - "mask": init_mask, - } - ] - else: - raise ValueError(f"Unrecognized image field {init_image}") - - np_images = [] - for image_dict in images: - assert isinstance(image_dict, dict) - image = image_dict.get("image") - mask = image_dict.get("mask") - assert image is not None - - np_image = self.parse_image(image) - np_mask = self.parse_image(mask) if mask is not None else None - np_images.append(self.combine_image_and_mask(np_image, np_mask)) # [H, W, 4] - - return np_images - - @classmethod - def from_dict(cls, values: dict) -> ControlNetUnit: - values = copy(values) - values = cls.legacy_field_alias(values) - values = cls.mask_alias(values) - return ControlNetUnit(**values) - - @classmethod - def from_infotext_args(cls, *args) -> ControlNetUnit: - assert len(args) == len(ControlNetUnit.infotext_fields()) - return cls.from_dict( - {k: v for k, v in zip(ControlNetUnit.infotext_fields(), args)} - ) - - @staticmethod - def infotext_fields() -> Tuple[str]: - """Fields that should be included in infotext. - You should define a Gradio element with exact same name in ControlNetUiGroup - as well, so that infotext can wire the value to correct field when pasting - infotext. - """ - return ( - "module", - "model", - "weight", - "resize_mode", - "processor_res", - "threshold_a", - "threshold_b", - "guidance_start", - "guidance_end", - "pixel_perfect", - "control_mode", - ) - - def serialize(self) -> str: - """Serialize the unit for infotext.""" - infotext_dict = { - field_to_displaytext(field): serialize_value(getattr(self, field)) - for field in ControlNetUnit.infotext_fields() - } - if not all( - "," not in str(v) and ":" not in str(v) for v in infotext_dict.values() - ): - self.cls_logger.error(f"Unexpected tokens encountered:\n{infotext_dict}") - return "" - - return ", ".join(f"{field}: {value}" for field, value in infotext_dict.items()) - - @classmethod - def parse(cls, text: str) -> ControlNetUnit: - return ControlNetUnit( - enabled=True, - **{ - displaytext_to_field(key): parse_value(value) - for item in text.split(",") - for (key, value) in (item.strip().split(": "),) - }, - ) diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py index 157f08efc..016142e5a 100644 --- a/internal_controlnet/external_code.py +++ b/internal_controlnet/external_code.py @@ -1,30 +1,58 @@ +import base64 +import io +from dataclasses import dataclass +from enum import Enum from copy import copy from typing import List, Any, Optional, Union, Tuple, Dict +import torch import numpy as np from modules import scripts, processing, shared -from modules.api import api -from .args import ControlNetUnit +from modules.safe import unsafe_torch_load from scripts import global_state from scripts.logging import logger -from scripts.enums import ( - ResizeMode, - BatchOption, # noqa: F401 - ControlMode, # noqa: F401 -) -from scripts.supported_preprocessor import ( - Preprocessor, - PreprocessorParameter, # noqa: F401 -) +from scripts.enums import HiResFixOption +from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter -import torch -import base64 -import io -from modules.safe import unsafe_torch_load +from modules.api import api def get_api_version() -> int: - return 3 + return 2 + + +class ControlMode(Enum): + """ + The improved guess mode. + """ + + BALANCED = "Balanced" + PROMPT = "My prompt is more important" + CONTROL = "ControlNet is more important" + + +class BatchOption(Enum): + DEFAULT = "All ControlNet units for all images in a batch" + SEPARATE = "Each ControlNet unit for each image in a batch" + + +class ResizeMode(Enum): + """ + Resize modes for ControlNet input images. + """ + + RESIZE = "Just Resize" + INNER_FIT = "Crop and Resize" + OUTER_FIT = "Resize and Fill" + + def int_value(self): + if self == ResizeMode.RESIZE: + return 0 + elif self == ResizeMode.INNER_FIT: + return 1 + elif self == ResizeMode.OUTER_FIT: + return 2 + assert False, "NOTREACHED" resize_mode_aliases = { @@ -54,6 +82,15 @@ def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode: return value +def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode: + if isinstance(value, str): + return ControlMode(value) + elif isinstance(value, int): + return [e for e in ControlMode][value] + else: + return value + + def visualize_inpaint_mask(img): if img.ndim == 3 and img.shape[2] == 4: result = img.copy() @@ -115,7 +152,146 @@ def pixel_perfect_resolution( return int(np.round(estimation)) -def to_base64_nparray(encoding: str) -> np.ndarray: +InputImage = Union[np.ndarray, str] +InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] + + +@dataclass +class ControlNetUnit: + """ + Represents an entire ControlNet processing unit. + """ + + enabled: bool = True + module: str = "none" + model: str = "None" + weight: float = 1.0 + image: Optional[Union[InputImage, List[InputImage]]] = None + resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT + low_vram: bool = False + processor_res: int = -1 + threshold_a: float = -1 + threshold_b: float = -1 + guidance_start: float = 0.0 + guidance_end: float = 1.0 + pixel_perfect: bool = False + control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area` + # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`. + inpaint_crop_input_image: bool = True + # If hires fix is enabled in A1111, how should this ControlNet unit be applied. + # The value is ignored if the generation is not using hires fix. + hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH + + # Whether save the detected map of this unit. Setting this option to False prevents saving the + # detected map or sending detected map along with generated images via API. + # Currently the option is only accessible in API calls. + save_detected_map: bool = True + + # Weight for each layer of ControlNet params. + # For ControlNet: + # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block) + # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block) + # For T2IAdapter + # - SD1.5: 5 weights (4 encoder block + 1 middle block) + # - SDXL: 4 weights (3 encoder block + 1 middle block) + # For IPAdapter + # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block) + # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block) + # Note1: Setting advanced weighting will disable `soft_injection`, i.e. + # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`. + # Note2: The field `weight` is still used in some places, e.g. reference_only, + # even advanced_weighting is set. + advanced_weighting: Optional[List[float]] = None + + # The effective region mask that unit's effect should be restricted to. + effective_region_mask: Optional[np.ndarray] = None + + # The tensor input for ipadapter. When this field is set in the API, + # the base64string will be interpret by torch.load to reconstruct ipadapter + # preprocessor output. + # Currently the option is only accessible in API calls. + ipadapter_input: Optional[List[Any]] = None + + def __eq__(self, other): + if not isinstance(other, ControlNetUnit): + return False + + return vars(self) == vars(other) + + def accepts_multiple_inputs(self) -> bool: + """This unit can accept multiple input images.""" + return self.module in ( + "ip-adapter_clip_sdxl", + "ip-adapter_clip_sdxl_plus_vith", + "ip-adapter_clip_sd15", + "ip-adapter_face_id", + "ip-adapter_face_id_plus", + "instant_id_face_embedding", + ) + + @staticmethod + def infotext_excluded_fields() -> List[str]: + return [ + "image", + "enabled", + # API-only fields. + "advanced_weighting", + "ipadapter_input", + # End of API-only fields. + # Note: "inpaint_crop_image" is img2img inpaint only flag, which does not + # provide much information when restoring the unit. + "inpaint_crop_input_image", + "effective_region_mask", + ] + + @property + def is_animate_diff_batch(self) -> bool: + return getattr(self, "animatediff_batch", False) + + @property + def uses_clip(self) -> bool: + """Whether this unit uses clip preprocessor.""" + return any( + ( + ("ip-adapter" in self.module and "face_id" not in self.module), + self.module + in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"), + ) + ) + + @property + def is_inpaint(self) -> bool: + return "inpaint" in self.module + + def bound_check_params(self) -> None: + """ + Checks and corrects negative parameters in ControlNetUnit 'unit' in place. + Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to + their default values if negative. + """ + preprocessor = Preprocessor.get_preprocessor(self.module) + for unit_param, param in zip( + ("processor_res", "threshold_a", "threshold_b"), + ("slider_resolution", "slider_1", "slider_2"), + ): + value = getattr(self, unit_param) + cfg: PreprocessorParameter = getattr(preprocessor, param) + if value < 0: + setattr(self, unit_param, cfg.value) + logger.info( + f"[{self.module}.{unit_param}] Invalid value({value}), using default value {cfg.value}." + ) + + def get_actual_preprocessor(self) -> Preprocessor: + if self.module == "ip-adapter-auto": + return Preprocessor.get_preprocessor(self.module).get_preprocessor_by_model( + self.model + ) + return Preprocessor.get_preprocessor(self.module) + + +def to_base64_nparray(encoding: str): """ Convert a base64 image into the image type the extension uses """ @@ -220,14 +396,73 @@ def get_max_models_num(): return max_models_num -def to_processing_unit(unit: Union[Dict, ControlNetUnit]) -> ControlNetUnit: +def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit: """ Convert different types to processing unit. + If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details. """ + + ext_compat_keys = { + "guessmode": "guess_mode", + "guidance": "guidance_end", + "lowvram": "low_vram", + "input_image": "image", + } + if isinstance(unit, dict): - return ControlNetUnit.from_dict(unit) + unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} + + # Handle mask + mask = None + if "mask" in unit: + mask = unit["mask"] + del unit["mask"] + + if "mask_image" in unit: + mask = unit["mask_image"] + del unit["mask_image"] + + if "image" in unit and not isinstance(unit["image"], dict): + unit["image"] = ( + {"image": unit["image"], "mask": mask} + if mask is not None + else unit["image"] if unit["image"] else None + ) + + # Parse ipadapter_input + if "ipadapter_input" in unit and unit["ipadapter_input"] is not None: + + def decode_base64(b: str) -> torch.Tensor: + decoded_bytes = base64.b64decode(b) + return unsafe_torch_load(io.BytesIO(decoded_bytes)) + + if isinstance(unit["ipadapter_input"], str): + unit["ipadapter_input"] = [unit["ipadapter_input"]] - assert isinstance(unit, ControlNetUnit) + unit["ipadapter_input"] = [ + decode_base64(b) for b in unit["ipadapter_input"] + ] + + if unit.get("effective_region_mask", None) is not None: + base64img = unit["effective_region_mask"] + assert isinstance(base64img, str) + unit["effective_region_mask"] = to_base64_nparray(base64img) + + if "guess_mode" in unit: + logger.warning( + "Guess Mode is removed since 1.1.136. Please use Control Mode instead." + ) + + for k in unit.keys(): + if k not in vars(ControlNetUnit): + logger.warn(f"Received unrecognized key '{k}' in API.") + + unit = ControlNetUnit( + **{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()} + ) + + # temporary, check #602 + # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' return unit @@ -414,23 +649,3 @@ def is_cn_script(script: scripts.Script) -> bool: """ return script.title().lower() == "controlnet" - - -# TODO: Add model constraint -ControlNetUnit.cls_match_model = lambda model: True -ControlNetUnit.cls_match_module = ( - lambda module: Preprocessor.get_preprocessor(module) is not None -) -ControlNetUnit.cls_get_preprocessor = Preprocessor.get_preprocessor -ControlNetUnit.cls_decode_base64 = to_base64_nparray - - -def decode_base64(b: str) -> torch.Tensor: - decoded_bytes = base64.b64decode(b) - return unsafe_torch_load(io.BytesIO(decoded_bytes)) - - -ControlNetUnit.cls_torch_load_base64 = decode_base64 -ControlNetUnit.cls_logger = logger - -logger.debug("ControlNetUnit initialized") diff --git a/javascript/controlnet_unit.mjs b/javascript/controlnet_unit.mjs index 147ecebaa..53668bb12 100644 --- a/javascript/controlnet_unit.mjs +++ b/javascript/controlnet_unit.mjs @@ -75,6 +75,7 @@ export class ControlNetUnit { this.attachImageUploadListener(); this.attachImageStateChangeObserver(); this.attachA1111SendInfoObserver(); + this.attachPresetDropdownObserver(); } getTabNavButton() { @@ -268,4 +269,24 @@ export class ControlNetUnit { }); } } + + attachPresetDropdownObserver() { + const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown'); + + new MutationObserver((mutationsList) => { + for (const mutation of mutationsList) { + if (mutation.removedNodes.length > 0) { + setTimeout(() => { + this.updateActiveState(); + this.updateActiveUnitCount(); + this.updateActiveControlType(); + }, 1000); + return; + } + } + }).observe(presetDropDown, { + childList: true, + subtree: true, + }); + } } diff --git a/requirements.txt b/requirements.txt index f0072d277..5fb5e2e2b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,3 @@ addict yapf albumentations==1.4.3 matplotlib -facexlib diff --git a/scripts/api.py b/scripts/api.py index fb6df759e..c2d348a3a 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -137,12 +137,12 @@ async def detect( ) unit = ControlNetUnit( - enabled=True, module=preprocessor.label, processor_res=controlnet_processor_res, threshold_a=controlnet_threshold_a, threshold_b=controlnet_threshold_b, ) + unit.bound_check_params() tensors = [] images = [] @@ -179,7 +179,7 @@ def accept(self, json_dict: dict) -> None: low_vram=low_vram, ) if preprocessor.returns_image: - images.append(encode_to_base64(result.display_images[0])) + images.append(encode_to_base64(result.display_image)) else: tensors.append(encode_tensor_to_base64(result.value)) diff --git a/scripts/batch_hijack.py b/scripts/batch_hijack.py index 8e72c9e3b..fe001e610 100644 --- a/scripts/batch_hijack.py +++ b/scripts/batch_hijack.py @@ -1,10 +1,10 @@ import os +from copy import copy from typing import Tuple, List from modules import img2img, processing, shared, script_callbacks from scripts import external_code from scripts.enums import InputMode -from scripts.logging import logger class BatchHijack: def __init__(self): @@ -194,7 +194,7 @@ def unhijack_function(module, name, new_name): def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]: units = external_code.get_all_units_in_processing(p) - units = [unit.copy() for unit in units if getattr(unit, 'enabled', False)] + units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)] any_unit_is_batch = False output_dir = '' input_file_names = [] @@ -222,8 +222,6 @@ def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[ else: batches[i].append(unit.image) - if any_unit_is_batch: - logger.info(f"Batch enabled ({len(batches)})") return any_unit_is_batch, batches, output_dir, input_file_names diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 35b349ff2..3845d537e 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -4,9 +4,10 @@ import logging from collections import OrderedDict from copy import copy, deepcopy -from typing import Dict, Optional, Tuple, List +from typing import Dict, Optional, Tuple, List, Union import modules.scripts as scripts from modules import shared, devices, script_callbacks, processing, masking, images +from modules.api.api import decode_base64_to_image import gradio as gr import time @@ -15,25 +16,14 @@ # Register all preprocessors. import scripts.preprocessor as preprocessor_init # noqa from annotator.util import HWC3 -from internal_controlnet.external_code import ControlNetUnit from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lllite import clear_all_lllite from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter -from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent from scripts.hook import ControlParams, UnetHook, HackedImageRNG -from scripts.enums import ( - ControlModelType, - InputMode, - StableDiffusionVersion, - HiResFixOption, - PuLIDMode, - ControlMode, - BatchOption, - ResizeMode, -) -from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup +from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption +from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from scripts.controlnet_ui.photopea import Photopea from scripts.logging import logger from scripts.supported_preprocessor import Preprocessor @@ -99,7 +89,44 @@ def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img): global_state.update_cn_models() -logger.info(f"ControlNet {controlnet_version.version_flag}") + + +def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: + if image is None: + return None + + if isinstance(image, (tuple, list)): + image = {'image': image[0], 'mask': image[1]} + elif not isinstance(image, dict): + image = {'image': image, 'mask': None} + else: # type(image) is dict + # copy to enable modifying the dict and prevent response serialization error + image = dict(image) + + if isinstance(image['image'], str): + if os.path.exists(image['image']): + image['image'] = np.array(Image.open(image['image'])).astype('uint8') + elif image['image']: + image['image'] = external_code.to_base64_nparray(image['image']) + else: + image['image'] = None + + # If there is no image, return image with None image and None mask + if image['image'] is None: + image['mask'] = None + return image + + if 'mask' not in image or image['mask'] is None: + image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) + elif isinstance(image['mask'], str): + if os.path.exists(image['mask']): + image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8') + elif image['mask']: + image['mask'] = external_code.to_base64_nparray(image['mask']) + else: + image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) + + return image def prepare_mask( @@ -192,7 +219,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor: def get_control( p: StableDiffusionProcessing, - unit: ControlNetUnit, + unit: external_code.ControlNetUnit, idx: int, control_model_type: ControlModelType, preprocessor: Preprocessor, @@ -205,12 +232,12 @@ def get_control( h, w, hr_y, hr_x = Script.get_target_dimensions(p) input_image, resize_mode = Script.choose_input_image(p, unit, idx) if isinstance(input_image, list): - assert unit.accepts_multiple_inputs or unit.is_animate_diff_batch + assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch input_images = input_image else: # Following operations are only for single input image. input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy - if unit.module == 'inpaint_only+lama' and resize_mode == ResizeMode.OUTER_FIT: + if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT: # inpaint_only+lama is special and required outpaint fix _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x) input_images = [input_image] @@ -252,7 +279,6 @@ def preprocess_input_image(input_image: np.ndarray): ) detected_map = result.value is_image = preprocessor.returns_image - # TODO: Refactor img control detection logic. if high_res_fix: if is_image: hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) @@ -267,8 +293,7 @@ def preprocess_input_image(input_image: np.ndarray): store_detected_map(detected_map, unit.module) else: control = detected_map - for image in result.display_images: - store_detected_map(image, unit.module) + store_detected_map(input_image, unit.module) if control_model_type == ControlModelType.T2I_StyleAdapter: control = control['last_hidden_state'] @@ -302,11 +327,11 @@ def __init__(self) -> None: self.latest_network = None self.input_image = None self.latest_model_hash = "" - self.enabled_units: List[ControlNetUnit] = [] + self.enabled_units: List[external_code.ControlNetUnit] = [] self.detected_map = [] self.post_processors = [] self.noise_modifier = None - self.ui_batch_option_state = [BatchOption.DEFAULT.value, False] + self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False] batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process) batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each) batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each) @@ -318,14 +343,27 @@ def title(self): def show(self, is_img2img): return scripts.AlwaysVisible + @staticmethod + def get_default_ui_unit(is_ui=True): + cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit + return cls( + enabled=False, + module="none", + model="None" + ) + def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]: - group = ControlNetUiGroup(is_img2img, photopea) + group = ControlNetUiGroup( + is_img2img, + Script.get_default_ui_unit(), + photopea, + ) return group, group.render(tabname, elem_id_tabname) def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str): batch_option = gr.Radio( - choices=[e.value for e in BatchOption], - value=BatchOption.DEFAULT.value, + choices=[e.value for e in external_code.BatchOption], + value=external_code.BatchOption.DEFAULT.value, label="Batch Option", elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio", elem_classes="controlnet_batch_option_radio", @@ -478,7 +516,7 @@ def get_element(obj, strict=False): return attribute_value if attribute_value is not None else default @staticmethod - def parse_remote_call(p, unit: ControlNetUnit, idx): + def parse_remote_call(p, unit: external_code.ControlNetUnit, idx): selector = Script.get_remote_call unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True) @@ -574,7 +612,7 @@ def high_quality_resize(x, size): return y - if resize_mode == ResizeMode.RESIZE: + if resize_mode == external_code.ResizeMode.RESIZE: detected_map = high_quality_resize(detected_map, (w, h)) detected_map = safe_numpy(detected_map) return get_pytorch_control(detected_map), detected_map @@ -587,7 +625,7 @@ def high_quality_resize(x, size): safeint = lambda x: int(np.round(x)) - if resize_mode == ResizeMode.OUTER_FIT: + if resize_mode == external_code.ResizeMode.OUTER_FIT: k = min(k0, k1) borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0) high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype) @@ -615,31 +653,10 @@ def high_quality_resize(x, size): @staticmethod def get_enabled_units(p): - def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]: - """Unfolds a merged unit to multiple units. Keeps the unit merged for - preprocessors that can accept multiple input images. - """ - if unit.input_mode != InputMode.MERGE: - return [unit] - - if unit.accepts_multiple_inputs: - unit.input_mode = InputMode.SIMPLE - return [unit] - - assert isinstance(unit.image, list) - result = [] - for image in unit.image: - u = unit.copy() - u.image = [image] - u.input_mode = InputMode.SIMPLE - u.weight = unit.weight / len(unit.image) - result.append(u) - return result - units = external_code.get_all_units_in_processing(p) if len(units) == 0: # fill a null group - remote_unit = Script.parse_remote_call(p, ControlNetUnit(), 0) + remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0) if remote_unit.enabled: units.append(remote_unit) @@ -648,7 +665,11 @@ def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]: local_unit = Script.parse_remote_call(p, unit, idx) if not local_unit.enabled: continue - enabled_units.extend(unfold_merged(local_unit)) + + if hasattr(local_unit, "unfold_merged"): + enabled_units.extend(local_unit.unfold_merged()) + else: + enabled_units.append(copy(local_unit)) Infotext.write_infotext(enabled_units, p) return enabled_units @@ -656,37 +677,49 @@ def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]: @staticmethod def choose_input_image( p: processing.StableDiffusionProcessing, - unit: ControlNetUnit, + unit: external_code.ControlNetUnit, idx: int - ) -> Tuple[np.ndarray, ResizeMode]: + ) -> Tuple[np.ndarray, external_code.ResizeMode]: """ Choose input image from following sources with descending priority: - p.image_control: [Deprecated] Lagacy way to pass image to controlnet. - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet. - - unit.image: ControlNet unit input image. - - p.init_images: A1111 img2img input image. + - unit.image: ControlNet tab input image. + - p.init_images: A1111 img2img tab input image. Returns: - The input image in ndarray form. - The resize mode. """ - def from_rgba_to_input(img: np.ndarray) -> np.ndarray: - if ( - shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) or - (img[:, :, 3] <= 5).all() or - (img[:, :, 3] >= 250).all() - ): - # Take RGB - return img[:, :, :3] - logger.info("Canvas scribble mode. Using mask scribble as input.") - return HWC3(img[:, :, 3]) + def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + unit_has_multiple_images = ( + isinstance(unit.image, list) and + len(unit.image) > 0 and + "image" in unit.image[0] + ) + if unit_has_multiple_images: + return [ + d + for img in unit.image + for d in (image_dict_from_any(img),) + if d is not None + ] + return image_dict_from_any(unit.image) + + def decode_image(img) -> np.ndarray: + """Need to check the image for API compatibility.""" + if isinstance(img, str): + return np.asarray(decode_base64_to_image(image['image'])) + else: + assert isinstance(img, np.ndarray) + return img # 4 input image sources. p_image_control = getattr(p, "image_control", None) p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx) - image = unit.get_input_images_rgba() + image = parse_unit_image(unit) a1111_image = getattr(p, "init_images", [None])[0] - resize_mode = unit.resize_mode + resize_mode = external_code.resize_mode_from_value(unit.resize_mode) if batch_hijack.instance.is_batch and p_image_control is not None: logger.warning("Warn: Using legacy field 'p.image_control'.") @@ -699,18 +732,42 @@ def from_rgba_to_input(img: np.ndarray) -> np.ndarray: input_image = np.concatenate([color, alpha], axis=2) else: input_image = HWC3(np.asarray(p_input_image)) - elif image is not None: - assert isinstance(image, list) - # Inpaint mask or CLIP mask. - if unit.is_inpaint or unit.uses_clip: - # RGBA - input_image = image + elif image: + if isinstance(image, list): + # Add mask logic if later there is a processor that accepts mask + # on multiple inputs. + input_image = [HWC3(decode_image(img['image'])) for img in image] + if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None: + for idx in range(len(input_image)): + while len(image[idx]['mask'].shape) < 3: + image[idx]['mask'] = image[idx]['mask'][..., np.newaxis] + if unit.is_inpaint or unit.uses_clip: + color = HWC3(image[idx]["image"]) + alpha = image[idx]['mask'][:, :, 0:1] + input_image[idx] = np.concatenate([color, alpha], axis=2) else: - # RGB - input_image = [from_rgba_to_input(img) for img in image] - - if len(input_image) == 1: - input_image = input_image[0] + input_image = HWC3(decode_image(image['image'])) + if 'mask' in image and image['mask'] is not None: + while len(image['mask'].shape) < 3: + image['mask'] = image['mask'][..., np.newaxis] + if unit.is_inpaint or unit.uses_clip: + logger.info("using mask") + color = HWC3(image['image']) + alpha = image['mask'][:, :, 0:1] + input_image = np.concatenate([color, alpha], axis=2) + elif ( + not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and + # There is wield gradio issue that would produce mask that is + # not pure color when no scribble is made on canvas. + # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638. + not ( + (image['mask'][:, :, 0] <= 5).all() or + (image['mask'][:, :, 0] >= 250).all() + ) + ): + logger.info("using mask as input") + input_image = HWC3(image['mask'][:, :, 0]) + unit.module = 'none' # Always use black bg and white line elif a1111_image is not None: input_image = HWC3(np.asarray(a1111_image)) a1111_i2i_resize_mode = getattr(p, "resize_mode", None) @@ -742,9 +799,9 @@ def from_rgba_to_input(img: np.ndarray) -> np.ndarray: @staticmethod def try_crop_image_with_a1111_mask( p: StableDiffusionProcessing, - unit: ControlNetUnit, + unit: external_code.ControlNetUnit, input_image: np.ndarray, - resize_mode: ResizeMode, + resize_mode: external_code.ResizeMode, ) -> np.ndarray: """ Crop ControlNet input image based on A1111 inpaint mask given. @@ -786,7 +843,7 @@ def try_crop_image_with_a1111_mask( input_image = [x.crop(crop_region) for x in input_image] input_image = [ - images.resize_image(ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) + images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) for x in input_image ] @@ -795,7 +852,7 @@ def try_crop_image_with_a1111_mask( return input_image @staticmethod - def check_sd_version_compatible(unit: ControlNetUnit) -> None: + def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: """ Checks whether the given ControlNet unit has model compatible with the currently active sd model. An exception is thrown if ControlNet unit is detected to be @@ -861,7 +918,7 @@ def controlnet_main_entry(self, p): if not batch_hijack.instance.is_batch: self.enabled_units = Script.get_enabled_units(p) - batch_option_uint_separate = self.ui_batch_option_state[0] == BatchOption.SEPARATE.value + batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value batch_option_style_align = self.ui_batch_option_state[1] if len(self.enabled_units) == 0 and not batch_option_style_align: @@ -888,6 +945,7 @@ def controlnet_main_entry(self, p): high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) for idx, unit in enumerate(self.enabled_units): + unit.bound_check_params() Script.check_sd_version_compatible(unit) if ( 'inpaint_only' == unit.module and @@ -941,7 +999,7 @@ def controlnet_main_entry(self, p): elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]: cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None) def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx): - if unit.accepts_multiple_inputs: + if unit.accepts_multiple_inputs(): ip_adapter_image_emb_cond = [] model_net.ipadapter.image_proj_model.to(torch.float32) # noqa for c in cc: @@ -968,7 +1026,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files): logger.info(f"\t{frame_idx}: {frame_path}") c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu() - elif unit.accepts_multiple_inputs: + elif unit.accepts_multiple_inputs(): # ip-adapter should do prompt travel logger.info("IP-Adapter: control prompts will be traveled in the following way:") for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files): @@ -997,7 +1055,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe c_full[cn_ad_keyframe_idx] = c c = c_full # handle batch condition and unconditional - if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs: + if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs(): c = torch.cat([c, c], dim=0) return c @@ -1020,6 +1078,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe control_model_type.is_controlnet and model_net.control_model.global_average_pooling ) + control_mode = external_code.control_mode_from_value(unit.control_mode) forward_param = ControlParams( control_model=model_net, preprocessor=preprocessor_dict, @@ -1032,9 +1091,9 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe control_model_type=control_model_type, global_average_pooling=global_average_pooling, hr_hint_cond=hr_control, - hr_option=unit.hr_option if high_res_fix else HiResFixOption.BOTH, - soft_injection=unit.control_mode != ControlMode.BALANCED, - cfg_injection=unit.control_mode == ControlMode.CONTROL, + hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH, + soft_injection=control_mode != external_code.ControlMode.BALANCED, + cfg_injection=control_mode == external_code.ControlMode.CONTROL, effective_region_mask=( get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :] if unit.effective_region_mask is not None @@ -1131,7 +1190,7 @@ def recolor_intensity_post_processing(x, i): is_low_vram = any(unit.low_vram for unit in self.enabled_units) - for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)): + for i, param in enumerate(forward_params): if param.control_model_type == ControlModelType.IPAdapter: if param.advanced_weighting is not None: logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}") @@ -1146,12 +1205,6 @@ def recolor_intensity_post_processing(x, i): weight = param.weight h, w, hr_y, hr_x = Script.get_target_dimensions(p) - if unit.pulid_mode == PuLIDMode.STYLE: - pulid_attn_setting = PULID_SETTING_STYLE - else: - assert unit.pulid_mode == PuLIDMode.FIDELITY - pulid_attn_setting = PULID_SETTING_FIDELITY - param.control_model.hook( model=unet, preprocessor_outputs=param.hint_cond, @@ -1162,7 +1215,6 @@ def recolor_intensity_post_processing(x, i): latent_width=w // 8, latent_height=h // 8, effective_region_mask=param.effective_region_mask, - pulid_attn_setting=pulid_attn_setting, ) if param.control_model_type == ControlModelType.Controlllite: param.control_model.hook( @@ -1300,7 +1352,7 @@ def batch_tab_process(self, p, batches, *args, **kwargs): unit.batch_images = iter([batch[unit_i] for batch in batches]) def batch_tab_process_each(self, p, *args, **kwargs): - for unit in self.enabled_units: + for unit_i, unit in enumerate(self.enabled_units): if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0: continue diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index 908f4f042..2bfc18e22 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -1,8 +1,8 @@ import json import gradio as gr import functools -import itertools -from typing import List, Optional, Union, Dict, Tuple, Literal, Any +from copy import copy +from typing import List, Optional, Union, Dict, Tuple, Literal from dataclasses import dataclass import numpy as np @@ -13,18 +13,12 @@ external_code, ) from annotator.util import HWC3 -from internal_controlnet.external_code import ControlNetUnit from scripts.logging import logger from scripts.controlnet_ui.openpose_editor import OpenposeEditor +from scripts.controlnet_ui.preset import ControlNetPresetUI from scripts.controlnet_ui.photopea import Photopea from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl -from scripts.enums import ( - InputMode, - HiResFixOption, - PuLIDMode, - ControlMode, - ResizeMode, -) +from scripts.enums import InputMode from modules import shared from modules.ui_components import FormRow, FormHTML, ToolButton @@ -127,39 +121,72 @@ def set_component(self, component: gr.components.Component): ) -def create_ui_unit( - input_mode: InputMode = InputMode.SIMPLE, - batch_images: Optional[Any] = None, - output_dir: str = "", - loopback: bool = False, - merge_gallery_files: List[Dict[Union[Literal["name"], Literal["data"]], str]] = [], - use_preview_as_input: bool = False, - generated_image: Optional[np.ndarray] = None, - *args, -) -> ControlNetUnit: - unit_dict = { - k: v - for k, v in zip( - vars(ControlNetUnit()).keys(), - itertools.chain( - [True, input_mode, batch_images, output_dir, loopback], args - ), - ) - } +class UiControlNetUnit(external_code.ControlNetUnit): + """The data class that stores all states of a ControlNetUnit.""" + + def __init__( + self, + input_mode: InputMode = InputMode.SIMPLE, + batch_images: Optional[Union[str, List[external_code.InputImage]]] = None, + output_dir: str = "", + loopback: bool = False, + merge_gallery_files: List[ + Dict[Union[Literal["name"], Literal["data"]], str] + ] = [], + use_preview_as_input: bool = False, + generated_image: Optional[np.ndarray] = None, + mask_image: Optional[np.ndarray] = None, + enabled: bool = True, + module: Optional[str] = None, + model: Optional[str] = None, + weight: float = 1.0, + image: Optional[Dict[str, np.ndarray]] = None, + *args, + **kwargs, + ): + if use_preview_as_input and generated_image is not None: + input_image = generated_image + module = "none" + else: + input_image = image - if use_preview_as_input and generated_image is not None: - input_image = generated_image - unit_dict["module"] = "none" - else: - input_image = unit_dict["image"] + # Prefer uploaded mask_image over hand-drawn mask. + if input_image is not None and mask_image is not None: + assert isinstance(input_image, dict) + input_image["mask"] = mask_image - if merge_gallery_files and input_mode == InputMode.MERGE: - input_image = [ - {"image": read_image(file["name"])} for file in merge_gallery_files - ] + if merge_gallery_files and input_mode == InputMode.MERGE: + input_image = [ + {"image": read_image(file["name"])} for file in merge_gallery_files + ] + + super().__init__(enabled, module, model, weight, input_image, *args, **kwargs) + self.is_ui = True + self.input_mode = input_mode + self.batch_images = batch_images + self.output_dir = output_dir + self.loopback = loopback + + def unfold_merged(self) -> List[external_code.ControlNetUnit]: + """Unfolds a merged unit to multiple units. Keeps the unit merged for + preprocessors that can accept multiple input images. + """ + if self.input_mode != InputMode.MERGE: + return [copy(self)] - unit_dict["image"] = input_image - return ControlNetUnit.from_dict(unit_dict) + if self.accepts_multiple_inputs(): + self.input_mode = InputMode.SIMPLE + return [copy(self)] + + assert isinstance(self.image, list) + result = [] + for image in self.image: + unit = copy(self) + unit.image = image["image"] + unit.input_mode = InputMode.SIMPLE + unit.weight = self.weight / len(self.image) + result.append(unit) + return result class ControlNetUiGroup(object): @@ -193,6 +220,7 @@ class ControlNetUiGroup(object): def __init__( self, is_img2img: bool, + default_unit: external_code.ControlNetUnit, photopea: Optional[Photopea], ): # Whether callbacks have been registered. @@ -201,13 +229,13 @@ def __init__( self.ui_initialized: bool = False self.is_img2img = is_img2img - self.default_unit = ControlNetUnit() + self.default_unit = default_unit self.photopea = photopea self.webcam_enabled = False self.webcam_mirrored = False # Note: All gradio elements declared in `render` will be defined as member variable. - # Update counter to trigger a force update of ControlNetUnit. + # Update counter to trigger a force update of UiControlNetUnit. # This is useful when a field with no event subscriber available changes. # e.g. gr.Gallery, gr.State, etc. self.update_unit_counter = None @@ -216,7 +244,7 @@ def __init__( self.generated_image_group = None self.generated_image = None self.mask_image_group = None - self.effective_region_mask = None + self.mask_image = None self.batch_tab = None self.batch_image_dir = None self.merge_tab = None @@ -254,6 +282,7 @@ def __init__( self.loopback = None self.use_preview_as_input = None self.openpose_editor = None + self.preset_panel = None self.upload_independent_img_in_img2img = None self.image_upload_panel = None self.save_detected_map = None @@ -264,10 +293,10 @@ def __init__( self.batch_image_dir_state = None self.output_dir_state = None self.advanced_weighting = gr.State(None) - self.pulid_mode = None # API-only fields self.ipadapter_input = gr.State(None) + self.effective_region_mask = gr.Image(value=None, visible=False) ControlNetUiGroup.all_ui_groups.append(self) @@ -300,13 +329,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: tool="sketch", elem_id=f"{elem_id_tabname}_{tabname}_input_image", elem_classes=["cnet-image"], - brush_color=( - shared.opts.img2img_inpaint_mask_brush_color - if hasattr( - shared.opts, "img2img_inpaint_mask_brush_color" - ) - else None - ), + brush_color=shared.opts.img2img_inpaint_mask_brush_color + if hasattr( + shared.opts, "img2img_inpaint_mask_brush_color" + ) + else None, ) self.image.preprocess = functools.partial( svg_preprocess, preprocess=self.image.preprocess @@ -342,11 +369,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: with gr.Group( visible=False, elem_classes=["cnet-mask-image-group"] ) as self.mask_image_group: - self.effective_region_mask = gr.Image( + self.mask_image = gr.Image( value=None, - label="Effective Region Mask", + label="Upload Mask", elem_id=f"{elem_id_tabname}_{tabname}_mask_image", - elem_classes=["cnet-effective-region-mask-image"], + elem_classes=["cnet-mask-image"], interactive=True, ) @@ -454,10 +481,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=not self.is_img2img, ) self.mask_upload = gr.Checkbox( - label="Effective Region Mask", + label="Mask Upload", value=False, elem_classes=["cnet-mask-upload"], elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox", + visible=not self.is_img2img, ) self.use_preview_as_input = gr.Checkbox( label="Preview as Input", @@ -487,11 +515,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]): - self.type_filter = ( - gr.Dropdown - if shared.opts.data.get("controlnet_control_type_dropdown", False) - else gr.Radio - )( + self.type_filter = (gr.Dropdown if shared.opts.data.get("controlnet_control_type_dropdown", False) else gr.Radio)( Preprocessor.get_all_preprocessor_tags(), label="Control Type", value="All", @@ -585,7 +609,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) self.control_mode = gr.Radio( - choices=[e.value for e in ControlMode], + choices=[e.value for e in external_code.ControlMode], value=self.default_unit.control_mode.value, label="Control Mode", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio", @@ -593,7 +617,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) self.resize_mode = gr.Radio( - choices=[e.value for e in ResizeMode], + choices=[e.value for e in external_code.ResizeMode], value=self.default_unit.resize_mode.value, label="Resize Mode", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio", @@ -602,7 +626,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) self.hr_option = gr.Radio( - choices=[e.value for e in HiResFixOption], + choices=[e.value for e in external_code.HiResFixOption], value=self.default_unit.hr_option.value, label="Hires-Fix Option", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio", @@ -610,18 +634,9 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=False, ) - self.pulid_mode = gr.Radio( - choices=[e.value for e in PuLIDMode], - value=self.default_unit.pulid_mode.value, - label="PuLID Mode", - elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio", - elem_classes="controlnet_pulid_mode_radio", - visible=False, - ) - self.loopback = gr.Checkbox( label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", - value=False, + value=self.default_unit.loopback, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox", elem_classes="controlnet_loopback_checkbox", visible=False, @@ -629,6 +644,10 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.advanced_weight_control.render() + self.preset_panel = ControlNetPresetUI( + id_prefix=f"{elem_id_tabname}_{tabname}_" + ) + self.batch_image_dir_state = gr.State("") self.output_dir_state = gr.State("") unit_args = ( @@ -642,6 +661,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.merge_gallery, self.use_preview_as_input, self.generated_image, + self.mask_image, # End of Non-persistent fields. self.enabled, self.module, @@ -661,17 +681,34 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.hr_option, self.save_detected_map, self.advanced_weighting, - self.effective_region_mask, - self.pulid_mode, ) - unit = gr.State(ControlNetUnit()) + unit = gr.State(self.default_unit) + for comp in unit_args + (self.update_unit_counter,): + event_subscribers = [] + if hasattr(comp, "edit"): + event_subscribers.append(comp.edit) + elif hasattr(comp, "click"): + event_subscribers.append(comp.click) + elif isinstance(comp, gr.Slider) and hasattr(comp, "release"): + event_subscribers.append(comp.release) + elif hasattr(comp, "change"): + event_subscribers.append(comp.change) + + if hasattr(comp, "clear"): + event_subscribers.append(comp.clear) + + for event_subscriber in event_subscribers: + event_subscriber( + fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit + ) + ( ControlNetUiGroup.a1111_context.img2img_submit_button if self.is_img2img else ControlNetUiGroup.a1111_context.txt2img_submit_button ).click( - fn=create_ui_unit, + fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit, queue=False, @@ -756,12 +793,10 @@ def refresh_all_models(model: str): def register_build_sliders(self): def build_sliders(module: str, pp: bool): preprocessor = Preprocessor.get_preprocessor(module) - slider_resolution_kwargs = ( - preprocessor.slider_resolution.gradio_update_kwargs.copy() - ) + slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy() if pp: - slider_resolution_kwargs["visible"] = False + slider_resolution_kwargs['visible'] = False grs = [ gr.update(**slider_resolution_kwargs), @@ -807,7 +842,9 @@ def filter_selected(k: str): gr.Dropdown.update( value=default_option, choices=filtered_preprocessor_list ), - gr.Dropdown.update(value=default_model, choices=filtered_model_list), + gr.Dropdown.update( + value=default_model, choices=filtered_model_list + ), ] self.type_filter.change( @@ -846,9 +883,7 @@ def sd_version_changed(type_filter: str, current_model: str): ) def register_run_annotator(self): - def run_annotator( - image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str - ): + def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str): if image is None: return ( gr.update(value=None, visible=True), @@ -911,16 +946,16 @@ def is_openpose(module: str): and shared.opts.data.get("controlnet_clip_detector_on_cpu", False) ), json_pose_callback=( - json_acceptor.accept if is_openpose(module) else None + json_acceptor.accept + if is_openpose(module) + else None ), model=model, ) return ( # Update to `generated_image` - gr.update( - value=result.display_images[0], visible=True, interactive=False - ), + gr.update(value=result.display_image, visible=True, interactive=False), # preprocessor_preview gr.update(value=True), # openpose editor @@ -935,16 +970,12 @@ def is_openpose(module: str): self.processor_res, self.threshold_a, self.threshold_b, - ( - ControlNetUiGroup.a1111_context.img2img_w_slider - if self.is_img2img - else ControlNetUiGroup.a1111_context.txt2img_w_slider - ), - ( - ControlNetUiGroup.a1111_context.img2img_h_slider - if self.is_img2img - else ControlNetUiGroup.a1111_context.txt2img_h_slider - ), + ControlNetUiGroup.a1111_context.img2img_w_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_w_slider, + ControlNetUiGroup.a1111_context.img2img_h_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_h_slider, self.pixel_perfect, self.resize_mode, self.model, @@ -1091,17 +1122,22 @@ def register_shift_upload_mask(self): else (gr.update(visible=True), gr.update()) ), inputs=[self.mask_upload], - outputs=[self.mask_image_group, self.effective_region_mask], + outputs=[self.mask_image_group, self.mask_image], show_progress=False, ) - def register_shift_pulid_mode(self): - self.model.change( - fn=lambda model: gr.update(visible="pulid" in model.lower()), - inputs=[self.model], - outputs=[self.pulid_mode], - show_progress=False, - ) + if self.upload_independent_img_in_img2img is not None: + self.upload_independent_img_in_img2img.change( + fn=lambda checked: ( + # Uncheck `upload_mask` when not using independent input. + gr.update(visible=False, value=False) + if not checked + else gr.update(visible=True) + ), + inputs=[self.upload_independent_img_in_img2img], + outputs=[self.mask_upload], + show_progress=False, + ) def register_sync_batch_dir(self): def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir): @@ -1205,7 +1241,6 @@ def register_core_callbacks(self): self.register_build_sliders() self.register_shift_preview() self.register_shift_upload_mask() - self.register_shift_pulid_mode() self.register_create_canvas() self.register_clear_preview() self.register_multi_images_upload() @@ -1215,6 +1250,14 @@ def register_core_callbacks(self): self.model, ) assert self.type_filter is not None + self.preset_panel.register_callbacks( + self, + self.type_filter, + *[ + getattr(self, key) + for key in vars(external_code.ControlNetUnit()).keys() + ], + ) self.advanced_weight_control.register_callbacks( self.weight, self.advanced_weighting, diff --git a/scripts/controlnet_ui/preset.py b/scripts/controlnet_ui/preset.py new file mode 100644 index 000000000..3010d2617 --- /dev/null +++ b/scripts/controlnet_ui/preset.py @@ -0,0 +1,305 @@ +import os +import gradio as gr + +from typing import Dict, List + +from modules import scripts +from modules.ui_components import ToolButton +from scripts.infotext import parse_unit, serialize_unit +from scripts.logging import logger +from scripts import external_code +from scripts.supported_preprocessor import Preprocessor + +save_symbol = "\U0001f4be" # 💾 +delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ +refresh_symbol = "\U0001f504" # 🔄 +reset_symbol = "\U000021A9" # ↩ + +NEW_PRESET = "New Preset" + + +def load_presets(preset_dir: str) -> Dict[str, str]: + if not os.path.exists(preset_dir): + os.makedirs(preset_dir) + return {} + + presets = {} + for filename in os.listdir(preset_dir): + if filename.endswith(".txt"): + with open(os.path.join(preset_dir, filename), "r") as f: + name = filename.replace(".txt", "") + if name == NEW_PRESET: + continue + presets[name] = f.read() + return presets + + +def infer_control_type(module: str, model: str) -> str: + p = Preprocessor.get_preprocessor(module) + assert p is not None + matched_tags = [ + tag + for tag in p.tags + if any(f in model.lower() for f in Preprocessor.tag_to_filters(tag)) + ] + if len(matched_tags) != 1: + raise ValueError( + f"Unable to infer control type from module {module} and model {model}" + ) + return matched_tags[0] + + +class ControlNetPresetUI(object): + preset_directory = os.path.join(scripts.basedir(), "presets") + presets = load_presets(preset_directory) + + def __init__(self, id_prefix: str): + with gr.Row(): + self.dropdown = gr.Dropdown( + label="Presets", + show_label=True, + elem_classes=["cnet-preset-dropdown"], + choices=ControlNetPresetUI.dropdown_choices(), + value=NEW_PRESET, + ) + self.reset_button = ToolButton( + value=reset_symbol, + elem_classes=["cnet-preset-reset"], + tooltip="Reset preset", + visible=False, + ) + self.save_button = ToolButton( + value=save_symbol, + elem_classes=["cnet-preset-save"], + tooltip="Save preset", + ) + self.delete_button = ToolButton( + value=delete_symbol, + elem_classes=["cnet-preset-delete"], + tooltip="Delete preset", + ) + self.refresh_button = ToolButton( + value=refresh_symbol, + elem_classes=["cnet-preset-refresh"], + tooltip="Refresh preset", + ) + + with gr.Box( + elem_classes=["popup-dialog", "cnet-preset-enter-name"], + elem_id=f"{id_prefix}_cnet_preset_enter_name", + ) as self.name_dialog: + with gr.Row(): + self.preset_name = gr.Textbox( + label="Preset name", + show_label=True, + lines=1, + elem_classes=["cnet-preset-name"], + ) + self.confirm_preset_name = ToolButton( + value=save_symbol, + elem_classes=["cnet-preset-confirm-name"], + tooltip="Save preset", + ) + + def register_callbacks( + self, + uigroup, + control_type: gr.Radio, + *ui_states, + ): + def apply_preset(name: str, control_type: str, *ui_states): + if name == NEW_PRESET: + return ( + gr.update(visible=False), + *( + (gr.skip(),) + * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + ), + ) + + assert name in ControlNetPresetUI.presets + + infotext = ControlNetPresetUI.presets[name] + preset_unit = parse_unit(infotext) + current_unit = external_code.ControlNetUnit(*ui_states) + preset_unit.image = None + current_unit.image = None + + # Do not compare module param that are not used in preset. + for module_param in ("processor_res", "threshold_a", "threshold_b"): + if getattr(preset_unit, module_param) == -1: + setattr(current_unit, module_param, -1) + + # No update necessary. + if vars(current_unit) == vars(preset_unit): + return ( + gr.update(visible=False), + *( + (gr.skip(),) + * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + ), + ) + + unit = preset_unit + + try: + new_control_type = infer_control_type(unit.module, unit.model) + except ValueError as e: + logger.error(e) + new_control_type = control_type + + return ( + gr.update(visible=True), + gr.update(value=new_control_type), + *[ + gr.update(value=value) if value is not None else gr.update() + for value in vars(unit).values() + ], + ) + + for element, action in ( + (self.dropdown, "change"), + (self.reset_button, "click"), + ): + getattr(element, action)( + fn=apply_preset, + inputs=[self.dropdown, control_type, *ui_states], + outputs=[self.delete_button, control_type, *ui_states], + show_progress="hidden", + ).then( + fn=lambda: gr.update(visible=False), + inputs=None, + outputs=[self.reset_button], + ) + + def save_preset(name: str, *ui_states): + if name == NEW_PRESET: + return gr.update(visible=True), gr.update(), gr.update() + + ControlNetPresetUI.save_preset( + name, external_code.ControlNetUnit(*ui_states) + ) + return ( + gr.update(), # name dialog + gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name), + gr.update(visible=False), # Reset button + ) + + self.save_button.click( + fn=save_preset, + inputs=[self.dropdown, *ui_states], + outputs=[self.name_dialog, self.dropdown, self.reset_button], + show_progress="hidden", + ).then( + fn=None, + _js=f""" + (name) => {{ + if (name === "{NEW_PRESET}") + popup(gradioApp().getElementById('{self.name_dialog.elem_id}')); + }}""", + inputs=[self.dropdown], + ) + + def delete_preset(name: str): + ControlNetPresetUI.delete_preset(name) + return gr.Dropdown.update( + choices=ControlNetPresetUI.dropdown_choices(), + value=NEW_PRESET, + ), gr.update(visible=False) + + self.delete_button.click( + fn=delete_preset, + inputs=[self.dropdown], + outputs=[self.dropdown, self.reset_button], + show_progress="hidden", + ) + + self.name_dialog.visible = False + + def save_new_preset(new_name: str, *ui_states): + if new_name == NEW_PRESET: + logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'") + return gr.update(visible=False), gr.update() + + ControlNetPresetUI.save_preset( + new_name, external_code.ControlNetUnit(*ui_states) + ) + return gr.update(visible=False), gr.update( + choices=ControlNetPresetUI.dropdown_choices(), value=new_name + ) + + self.confirm_preset_name.click( + fn=save_new_preset, + inputs=[self.preset_name, *ui_states], + outputs=[self.name_dialog, self.dropdown], + show_progress="hidden", + ).then(fn=None, _js="closePopup") + + self.refresh_button.click( + fn=ControlNetPresetUI.refresh_preset, + inputs=None, + outputs=[self.dropdown], + show_progress="hidden", + ) + + def update_reset_button(preset_name: str, *ui_states): + if preset_name == NEW_PRESET: + return gr.update(visible=False) + + infotext = ControlNetPresetUI.presets[preset_name] + preset_unit = parse_unit(infotext) + current_unit = external_code.ControlNetUnit(*ui_states) + preset_unit.image = None + current_unit.image = None + + # Do not compare module param that are not used in preset. + for module_param in ("processor_res", "threshold_a", "threshold_b"): + if getattr(preset_unit, module_param) == -1: + setattr(current_unit, module_param, -1) + + return gr.update(visible=vars(current_unit) != vars(preset_unit)) + + for ui_state in ui_states: + if isinstance(ui_state, gr.Image): + continue + + for action in ("edit", "click", "change", "clear", "release"): + if action == "release" and not isinstance(ui_state, gr.Slider): + continue + + if hasattr(ui_state, action): + getattr(ui_state, action)( + fn=update_reset_button, + inputs=[self.dropdown, *ui_states], + outputs=[self.reset_button], + ) + + @staticmethod + def dropdown_choices() -> List[str]: + return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] + + @staticmethod + def save_preset(name: str, unit: external_code.ControlNetUnit): + infotext = serialize_unit(unit) + with open( + os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w" + ) as f: + f.write(infotext) + + ControlNetPresetUI.presets[name] = infotext + + @staticmethod + def delete_preset(name: str): + if name not in ControlNetPresetUI.presets: + return + + del ControlNetPresetUI.presets[name] + + file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt") + if os.path.exists(file): + os.unlink(file) + + @staticmethod + def refresh_preset(): + ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory) + return gr.update(choices=ControlNetPresetUI.dropdown_choices()) diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py index 5b8222290..34173e04d 100644 --- a/scripts/controlnet_version.py +++ b/scripts/controlnet_version.py @@ -1,4 +1,8 @@ -version_flag = 'v1.1.448' +from scripts.logging import logger + +version_flag = 'v1.1.445' + +logger.info(f"ControlNet {version_flag}") # A smart trick to know if user has updated as well as if user has restarted terminal. # Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py" # This means if user did not completely restart terminal, the "version_flag" will be the previous version. diff --git a/scripts/enums.py b/scripts/enums.py index 477f38ff4..327f36431 100644 --- a/scripts/enums.py +++ b/scripts/enums.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, NamedTuple +from typing import Any, List, NamedTuple from functools import lru_cache @@ -224,6 +224,19 @@ class HiResFixOption(Enum): LOW_RES_ONLY = "Low res only" HIGH_RES_ONLY = "High res only" + @staticmethod + def from_value(value: Any) -> "HiResFixOption": + if isinstance(value, str) and value.startswith("HiResFixOption."): + _, field = value.split(".") + return getattr(HiResFixOption, field) + if isinstance(value, str): + return HiResFixOption(value) + elif isinstance(value, int): + return [x for x in HiResFixOption][value] + else: + assert isinstance(value, HiResFixOption) + return value + class InputMode(Enum): # Single image to a single ControlNet unit. @@ -234,42 +247,3 @@ class InputMode(Enum): # Input is a directory. 1 generation. Each generation takes N input image # from the directory. MERGE = "merge" - - -class PuLIDMode(Enum): - FIDELITY = "Fidelity" - STYLE = "Extremely style" - - -class ControlMode(Enum): - """ - The improved guess mode. - """ - - BALANCED = "Balanced" - PROMPT = "My prompt is more important" - CONTROL = "ControlNet is more important" - - -class BatchOption(Enum): - DEFAULT = "All ControlNet units for all images in a batch" - SEPARATE = "Each ControlNet unit for each image in a batch" - - -class ResizeMode(Enum): - """ - Resize modes for ControlNet input images. - """ - - RESIZE = "Just Resize" - INNER_FIT = "Crop and Resize" - OUTER_FIT = "Resize and Fill" - - def int_value(self): - if self == ResizeMode.RESIZE: - return 0 - elif self == ResizeMode.INNER_FIT: - return 1 - elif self == ResizeMode.OUTER_FIT: - return 2 - assert False, "NOTREACHED" diff --git a/scripts/infotext.py b/scripts/infotext.py index 68202a890..9dbdad1d8 100644 --- a/scripts/infotext.py +++ b/scripts/infotext.py @@ -1,13 +1,60 @@ -from typing import List, Tuple -from enum import Enum +from typing import List, Tuple, Union + import gradio as gr from modules.processing import StableDiffusionProcessing -from internal_controlnet.external_code import ControlNetUnit +from scripts import external_code from scripts.logging import logger +def field_to_displaytext(fieldname: str) -> str: + return " ".join([word.capitalize() for word in fieldname.split("_")]) + + +def displaytext_to_field(text: str) -> str: + return "_".join([word.lower() for word in text.split(" ")]) + + +def parse_value(value: str) -> Union[str, float, int, bool]: + if value in ("True", "False"): + return value == "True" + try: + return int(value) + except ValueError: + try: + return float(value) + except ValueError: + return value # Plain string. + + +def serialize_unit(unit: external_code.ControlNetUnit) -> str: + excluded_fields = external_code.ControlNetUnit.infotext_excluded_fields() + + log_value = { + field_to_displaytext(field): getattr(unit, field) + for field in vars(external_code.ControlNetUnit()).keys() + if field not in excluded_fields and getattr(unit, field) != -1 + # Note: exclude hidden slider values. + } + if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()): + logger.error(f"Unexpected tokens encountered:\n{log_value}") + return "" + + return ", ".join(f"{field}: {value}" for field, value in log_value.items()) + + +def parse_unit(text: str) -> external_code.ControlNetUnit: + return external_code.ControlNetUnit( + enabled=True, + **{ + displaytext_to_field(key): parse_value(value) + for item in text.split(",") + for (key, value) in (item.strip().split(": "),) + }, + ) + + class Infotext(object): def __init__(self) -> None: self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = [] @@ -27,7 +74,11 @@ def register_unit(self, unit_index: int, uigroup) -> None: iocomponents. """ unit_prefix = Infotext.unit_prefix(unit_index) - for field in ControlNetUnit.infotext_fields(): + for field in vars(external_code.ControlNetUnit()).keys(): + # Exclude image for infotext. + if field == "image": + continue + # Every field in ControlNetUnit should have a cooresponding # IOComponent in ControlNetUiGroup. io_component = getattr(uigroup, field) @@ -36,11 +87,13 @@ def register_unit(self, unit_index: int, uigroup) -> None: self.paste_field_names.append(component_locator) @staticmethod - def write_infotext(units: List[ControlNetUnit], p: StableDiffusionProcessing): + def write_infotext( + units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing + ): """Write infotext to `p`.""" p.extra_generation_params.update( { - Infotext.unit_prefix(i): unit.serialize() + Infotext.unit_prefix(i): serialize_unit(unit) for i, unit in enumerate(units) if unit.enabled } @@ -56,19 +109,14 @@ def on_infotext_pasted(infotext: str, results: dict) -> None: assert isinstance(v, str), f"Expect string but got {v}." try: - for field, value in vars(ControlNetUnit.parse(v)).items(): - if field not in ControlNetUnit.infotext_fields(): + for field, value in vars(parse_unit(v)).items(): + if field == "image": continue if value is None: - logger.debug( - f"InfoText: Skipping {field} because value is None." - ) + logger.debug(f"InfoText: Skipping {field} because value is None.") continue component_locator = f"{k} {field}" - if isinstance(value, Enum): - value = value.value - updates[component_locator] = value logger.debug(f"InfoText: Setting {component_locator} = {value}") except Exception as e: diff --git a/scripts/ipadapter/image_proj_models.py b/scripts/ipadapter/image_proj_models.py index d8dd12157..8594ac99b 100644 --- a/scripts/ipadapter/image_proj_models.py +++ b/scripts/ipadapter/image_proj_models.py @@ -269,65 +269,3 @@ def forward(self, x): latents = self.proj_out(latents) return self.norm_out(latents) - - -class PuLIDEncoder(nn.Module): - def __init__(self, width=1280, context_dim=2048, num_token=5): - super().__init__() - self.num_token = num_token - self.context_dim = context_dim - h1 = min((context_dim * num_token) // 4, 1024) - h2 = min((context_dim * num_token) // 2, 1024) - self.body = nn.Sequential( - nn.Linear(width, h1), - nn.LayerNorm(h1), - nn.LeakyReLU(), - nn.Linear(h1, h2), - nn.LayerNorm(h2), - nn.LeakyReLU(), - nn.Linear(h2, context_dim * num_token), - ) - - for i in range(5): - setattr( - self, - f"mapping_{i}", - nn.Sequential( - nn.Linear(1024, 1024), - nn.LayerNorm(1024), - nn.LeakyReLU(), - nn.Linear(1024, 1024), - nn.LayerNorm(1024), - nn.LeakyReLU(), - nn.Linear(1024, context_dim), - ), - ) - - setattr( - self, - f"mapping_patch_{i}", - nn.Sequential( - nn.Linear(1024, 1024), - nn.LayerNorm(1024), - nn.LeakyReLU(), - nn.Linear(1024, 1024), - nn.LayerNorm(1024), - nn.LeakyReLU(), - nn.Linear(1024, context_dim), - ), - ) - - def forward(self, x, y): - # x shape [N, C] - x = self.body(x) - x = x.reshape(-1, self.num_token, self.context_dim) - - hidden_states = () - for i, emb in enumerate(y): - hidden_state = getattr(self, f"mapping_{i}")(emb[:, :1]) + getattr( - self, f"mapping_patch_{i}" - )(emb[:, 1:]).mean(dim=1, keepdim=True) - hidden_states += (hidden_state,) - hidden_states = torch.cat(hidden_states, dim=1) - - return torch.cat([x, hidden_states], dim=1) diff --git a/scripts/ipadapter/ipadapter_model.py b/scripts/ipadapter/ipadapter_model.py index 16d9ac4c5..7314c9b2d 100644 --- a/scripts/ipadapter/ipadapter_model.py +++ b/scripts/ipadapter/ipadapter_model.py @@ -12,7 +12,6 @@ MLPProjModel, MLPProjModelFaceId, ProjModelFaceIdPlus, - PuLIDEncoder, ) @@ -72,7 +71,6 @@ def __init__( is_faceid: bool, is_portrait: bool, is_instantid: bool, - is_pulid: bool, is_v2: bool, ): super().__init__() @@ -87,12 +85,9 @@ def __init__( self.is_v2 = is_v2 self.is_faceid = is_faceid self.is_instantid = is_instantid - self.is_pulid = is_pulid self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4 - if self.is_pulid: - self.image_proj_model = PuLIDEncoder() - elif self.is_instantid: + if is_instantid: self.image_proj_model = self.init_proj_instantid() elif is_faceid: self.image_proj_model = self.init_proj_faceid() @@ -240,34 +235,6 @@ def _get_image_embeds_instantid( self.image_proj_model(torch.zeros_like(prompt_image_emb)), ) - def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed: - """Get image embeds for pulid.""" - id_cond = torch.cat( - [ - pulid_proj_input.id_ante_embedding.to( - device=self.device, dtype=torch.float32 - ), - pulid_proj_input.id_cond_vit.to( - device=self.device, dtype=torch.float32 - ), - ], - dim=-1, - ) - id_vit_hidden = [ - t.to(device=self.device, dtype=torch.float32) - for t in pulid_proj_input.id_vit_hidden - ] - return ImageEmbed( - self.image_proj_model( - id_cond, - id_vit_hidden, - ), - self.image_proj_model( - torch.zeros_like(id_cond), - [torch.zeros_like(t) for t in id_vit_hidden], - ), - ) - @staticmethod def load(state_dict: dict, model_name: str) -> IPAdapterModel: """ @@ -278,7 +245,6 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_v2 = "v2" in model_name is_faceid = "faceid" in model_name is_instantid = "instant_id" in model_name - is_pulid = "pulid" in model_name.lower() is_portrait = "portrait" in model_name is_full = "proj.3.weight" in state_dict["image_proj"] is_plus = ( @@ -290,8 +256,8 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: sdxl = cross_attention_dim == 2048 sdxl_plus = sdxl and is_plus - if is_instantid or is_pulid: - # InstantID/PuLID does not use clip embedding. + if is_instantid: + # InstantID does not use clip embedding. clip_embeddings_dim = None elif is_faceid: if is_plus: @@ -325,13 +291,10 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_portrait=is_portrait, is_instantid=is_instantid, is_v2=is_v2, - is_pulid=is_pulid, ) def get_image_emb(self, preprocessor_output) -> ImageEmbed: - if self.is_pulid: - return self._get_image_embeds_pulid(preprocessor_output) - elif self.is_instantid: + if self.is_instantid: return self._get_image_embeds_instantid(preprocessor_output) elif self.is_faceid and self.is_plus: # Note: FaceID plus uses both face_embed and clip_embed. diff --git a/scripts/ipadapter/plugable_ipadapter.py b/scripts/ipadapter/plugable_ipadapter.py index 72c0e6652..b56522489 100644 --- a/scripts/ipadapter/plugable_ipadapter.py +++ b/scripts/ipadapter/plugable_ipadapter.py @@ -1,9 +1,8 @@ import itertools import torch import math -from typing import Union, Dict, Optional, Callable +from typing import Union, Dict, Optional -from .pulid_attn import PuLIDAttnSetting from .ipadapter_model import ImageEmbed, IPAdapterModel from ..enums import StableDiffusionVersion, TransformerID @@ -94,7 +93,7 @@ def clear_all_ip_adapter(): class PlugableIPAdapter(torch.nn.Module): def __init__(self, ipadapter: IPAdapterModel): super().__init__() - self.ipadapter: IPAdapterModel = ipadapter + self.ipadapter = ipadapter self.disable_memory_management = True self.dtype = None self.weight: Union[float, Dict[int, float]] = 1.0 @@ -104,7 +103,6 @@ def __init__(self, ipadapter: IPAdapterModel): self.latent_width: int = 0 self.latent_height: int = 0 self.effective_region_mask = None - self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None def reset(self): self.cache = {} @@ -120,7 +118,6 @@ def hook( latent_width: int, latent_height: int, effective_region_mask: Optional[torch.Tensor], - pulid_attn_setting: Optional[PuLIDAttnSetting] = None, dtype=torch.float32, ): global current_model @@ -131,7 +128,6 @@ def hook( self.latent_width = latent_width self.latent_height = latent_height self.effective_region_mask = effective_region_mask - self.pulid_attn_setting = pulid_attn_setting self.cache = {} @@ -190,9 +186,7 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: # sequence_length = (latent_height * factor) * (latent_height * factor) # sequence_length = (latent_height * latent_height) * factor ^ 2 factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height)) - assert ( - factor > 0 - ), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" + assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" mask_h = int(self.latent_height * factor) mask_w = int(self.latent_width * factor) @@ -205,71 +199,6 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2]) return out * mask - def attn_eval( - self, - hidden_states: torch.Tensor, - query: torch.Tensor, - cond_uncond_image_emb: torch.Tensor, - attn_heads: int, - head_dim: int, - emb_to_k: Callable[[torch.Tensor], torch.Tensor], - emb_to_v: Callable[[torch.Tensor], torch.Tensor], - ): - if self.ipadapter.is_pulid: - assert self.pulid_attn_setting is not None - return self.pulid_attn_setting.eval( - hidden_states, - query, - cond_uncond_image_emb, - attn_heads, - head_dim, - emb_to_k, - emb_to_v, - ) - else: - return self._attn_eval_ipadapter( - hidden_states, - query, - cond_uncond_image_emb, - attn_heads, - head_dim, - emb_to_k, - emb_to_v, - ) - - def _attn_eval_ipadapter( - self, - hidden_states: torch.Tensor, - query: torch.Tensor, - cond_uncond_image_emb: torch.Tensor, - attn_heads: int, - head_dim: int, - emb_to_k: Callable[[torch.Tensor], torch.Tensor], - emb_to_v: Callable[[torch.Tensor], torch.Tensor], - ): - assert hidden_states.ndim == 3 - batch_size, sequence_length, inner_dim = hidden_states.shape - ip_k = emb_to_k(cond_uncond_image_emb) - ip_v = emb_to_v(cond_uncond_image_emb) - - ip_k, ip_v = map( - lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2), - (ip_k, ip_v), - ) - assert ip_k.dtype == ip_v.dtype - - # On MacOS, q can be float16 instead of float32. - # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 - if query.dtype != ip_k.dtype: - ip_k = ip_k.to(dtype=query.dtype) - ip_v = ip_v.to(dtype=query.dtype) - - ip_out = torch.nn.functional.scaled_dot_product_attention( - query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False - ) - ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) - return ip_out - @torch.no_grad() def patch_forward(self, number: int, transformer_index: int): @torch.no_grad() @@ -291,15 +220,27 @@ def forward(attn_blk, x, q): k_key = f"{number * 2 + 1}_to_k_ip" v_key = f"{number * 2 + 1}_to_v_ip" - ip_out = self.attn_eval( - hidden_states=x, - query=q, - cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark), - attn_heads=h, - head_dim=head_dim, - emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device), - emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device), + cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark) + ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device) + ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device) + + ip_k, ip_v = map( + lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2), + (ip_k, ip_v), ) + assert ip_k.dtype == ip_v.dtype + + # On MacOS, q can be float16 instead of float32. + # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 + if q.dtype != ip_k.dtype: + ip_k = ip_k.to(dtype=q.dtype) + ip_v = ip_v.to(dtype=q.dtype) + + ip_out = torch.nn.functional.scaled_dot_product_attention( + q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False + ) + ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + return self.apply_effective_region_mask(ip_out * weight) return forward diff --git a/scripts/ipadapter/presets.py b/scripts/ipadapter/presets.py index 764c83c98..275f70be5 100644 --- a/scripts/ipadapter/presets.py +++ b/scripts/ipadapter/presets.py @@ -166,12 +166,6 @@ def match_model(model_name: str) -> IPAdapterPreset: model="ip-adapter-faceid-portrait_sdxl", sd_version=StableDiffusionVersion.SDXL, ), - IPAdapterPreset( - name="pulid", - module="ip-adapter_pulid", - model="ip-adapter_pulid_sdxl_fp16", - sd_version=StableDiffusionVersion.SDXL, - ), ] _preset_by_model = {p.model: p for p in ipadapter_presets} diff --git a/scripts/ipadapter/pulid_attn.py b/scripts/ipadapter/pulid_attn.py deleted file mode 100644 index e2823470c..000000000 --- a/scripts/ipadapter/pulid_attn.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -import torch.nn.functional as F -from dataclasses import dataclass -from typing import Callable - - -@dataclass -class PuLIDAttnSetting: - num_zero: int = 0 - ortho: bool = False - ortho_v2: bool = False - - def eval( - self, - hidden_states: torch.Tensor, - query: torch.Tensor, - id_embedding: torch.Tensor, - attn_heads: int, - head_dim: int, - id_to_k: Callable[[torch.Tensor], torch.Tensor], - id_to_v: Callable[[torch.Tensor], torch.Tensor], - ): - assert hidden_states.ndim == 3 - batch_size, sequence_length, inner_dim = hidden_states.shape - - if self.num_zero == 0: - id_key = id_to_k(id_embedding).to(query.dtype) - id_value = id_to_v(id_embedding).to(query.dtype) - else: - zero_tensor = torch.zeros( - (id_embedding.size(0), self.num_zero, id_embedding.size(-1)), - dtype=id_embedding.dtype, - device=id_embedding.device, - ) - id_key = id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to( - query.dtype - ) - id_value = id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to( - query.dtype - ) - - id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - id_hidden_states = F.scaled_dot_product_attention( - query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - id_hidden_states = id_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn_heads * head_dim - ) - id_hidden_states = id_hidden_states.to(query.dtype) - - if not self.ortho and not self.ortho_v2: - return id_hidden_states - elif self.ortho_v2: - orig_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - id_hidden_states = id_hidden_states.to(torch.float32) - attn_map = query @ id_key.transpose(-2, -1) - attn_mean = attn_map.softmax(dim=-1).mean(dim=1) - attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True) - projection = ( - torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) - / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) - * hidden_states - ) - orthogonal = id_hidden_states + (attn_mean - 1) * projection - return orthogonal.to(orig_dtype) - else: - orig_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - id_hidden_states = id_hidden_states.to(torch.float32) - projection = ( - torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) - / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) - * hidden_states - ) - orthogonal = id_hidden_states - projection - return orthogonal.to(orig_dtype) - - -PULID_SETTING_FIDELITY = PuLIDAttnSetting( - num_zero=8, - ortho=False, - ortho_v2=True, -) - -PULID_SETTING_STYLE = PuLIDAttnSetting( - num_zero=16, - ortho=True, - ortho_v2=False, -) diff --git a/scripts/preprocessor/__init__.py b/scripts/preprocessor/__init__.py index b330e73ce..6bbcb762f 100644 --- a/scripts/preprocessor/__init__.py +++ b/scripts/preprocessor/__init__.py @@ -1,4 +1,3 @@ -from .pulid import * from .inpaint import * from .lama_inpaint import * from .ip_adapter_auto import * diff --git a/scripts/preprocessor/inpaint.py b/scripts/preprocessor/inpaint.py index 25874196a..4605dc252 100644 --- a/scripts/preprocessor/inpaint.py +++ b/scripts/preprocessor/inpaint.py @@ -1,7 +1,18 @@ -from scripts.utils import visualize_inpaint_mask +import numpy as np + from ..supported_preprocessor import Preprocessor, PreprocessorParameter +def visualize_inpaint_mask(img): + if img.ndim == 3 and img.shape[2] == 4: + result = img.copy() + mask = result[:, :, 3] + mask = 255 - mask // 2 + result[:, :, 3] = mask + return np.ascontiguousarray(result.copy()) + return img + + class PreprocessorInpaint(Preprocessor): def __init__(self): super().__init__(name="inpaint") @@ -12,6 +23,9 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True + def get_display_image(self, input_image: np.ndarray, result): + return visualize_inpaint_mask(result) + def __call__( self, input_image, @@ -21,10 +35,7 @@ def __call__( slider_3=None, **kwargs ): - return Preprocessor.Result( - value=input_image, - display_images=visualize_inpaint_mask(input_image)[None, :, :, :], - ) + return input_image class PreprocessorInpaintOnly(Preprocessor): @@ -36,6 +47,9 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True + def get_display_image(self, input_image: np.ndarray, result): + return visualize_inpaint_mask(result) + def __call__( self, input_image, @@ -45,10 +59,7 @@ def __call__( slider_3=None, **kwargs ): - return Preprocessor.Result( - value=input_image, - display_images=visualize_inpaint_mask(input_image)[None, :, :, :], - ) + return input_image Preprocessor.add_supported_preprocessor(PreprocessorInpaint()) diff --git a/scripts/preprocessor/lama_inpaint.py b/scripts/preprocessor/lama_inpaint.py index 1cd1c521c..33aff60bf 100644 --- a/scripts/preprocessor/lama_inpaint.py +++ b/scripts/preprocessor/lama_inpaint.py @@ -2,7 +2,7 @@ import numpy as np from ..supported_preprocessor import Preprocessor, PreprocessorParameter -from ..utils import resize_image_with_pad, visualize_inpaint_mask +from ..utils import resize_image_with_pad class PreprocessorLamaInpaint(Preprocessor): @@ -15,6 +15,12 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True + def get_display_image(self, input_image: np.ndarray, result: np.ndarray): + """For lama inpaint, display image should not contain mask.""" + assert result.ndim == 3 + assert result.shape[2] == 4 + return result[:, :, :3] + def __call__( self, input_image, @@ -50,13 +56,7 @@ def __call__( fin_color = fin_color.clip(0, 255).astype(np.uint8) result = np.concatenate([fin_color, raw_mask], axis=2) - return Preprocessor.Result( - value=result, - display_images=[ - result[:, :, :3], - visualize_inpaint_mask(result), - ], - ) + return result Preprocessor.add_supported_preprocessor(PreprocessorLamaInpaint()) diff --git a/scripts/preprocessor/legacy/legacy_preprocessors.py b/scripts/preprocessor/legacy/legacy_preprocessors.py index 7c5e1c873..902e6c9d0 100644 --- a/scripts/preprocessor/legacy/legacy_preprocessors.py +++ b/scripts/preprocessor/legacy/legacy_preprocessors.py @@ -93,7 +93,7 @@ def unload(self): def __call__( self, input_image, - resolution=512, + resolution, slider_1=None, slider_2=None, slider_3=None, diff --git a/scripts/preprocessor/model_free_preprocessors.py b/scripts/preprocessor/model_free_preprocessors.py index 54e6c1f6d..6b333e359 100644 --- a/scripts/preprocessor/model_free_preprocessors.py +++ b/scripts/preprocessor/model_free_preprocessors.py @@ -93,7 +93,7 @@ class PreprocessorBlurGaussian(Preprocessor): def __init__(self): super().__init__(name="blur_gaussian") self.slider_1 = PreprocessorParameter( - label="Sigma", minimum=0.01, maximum=64.0, value=9.0 + label="Sigma", minimum=64, maximum=2048, value=512 ) self.tags = ["Tile"] diff --git a/scripts/preprocessor/pulid.py b/scripts/preprocessor/pulid.py deleted file mode 100644 index a46f91290..000000000 --- a/scripts/preprocessor/pulid.py +++ /dev/null @@ -1,169 +0,0 @@ -# https://github.com/ToTheBeginning/PuLID - -import torch -import cv2 -import numpy as np -from typing import Optional, List -from dataclasses import dataclass -from facexlib.parsing import init_parsing_model -from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from torchvision.transforms.functional import normalize - -from ..supported_preprocessor import Preprocessor, PreprocessorParameter -from scripts.utils import npimg2tensor, tensor2npimg, resize_image_with_pad - - -def to_gray(img): - x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] - x = x.repeat(1, 3, 1, 1) - return x - - -class PreprocessorFaceXLib(Preprocessor): - def __init__(self): - super().__init__(name="facexlib") - self.tags = [] - self.slider_resolution = PreprocessorParameter(visible=False) - self.model: Optional[FaceRestoreHelper] = None - - def load_model(self): - if self.model is None: - self.model = FaceRestoreHelper( - upscale_factor=1, - face_size=512, - crop_ratio=(1, 1), - det_model="retinaface_resnet50", - save_ext="png", - device=self.device, - ) - self.model.face_parse = init_parsing_model( - model_name="bisenet", device=self.device - ) - self.model.face_parse.to(device=self.device) - self.model.face_det.to(device=self.device) - return self.model - - def unload(self) -> bool: - """@Override""" - if self.model is not None: - self.model.face_parse.to(device="cpu") - self.model.face_det.to(device="cpu") - return True - return False - - def __call__( - self, - input_image, - resolution=512, - slider_1=None, - slider_2=None, - slider_3=None, - input_mask=None, - return_tensor=False, - **kwargs - ): - """ - @Override - Returns black and white face features image with background removed. - """ - self.load_model() - self.model.clean_all() - input_image, _ = resize_image_with_pad(input_image, resolution) - # using facexlib to detect and align face - image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) - self.model.read_image(image_bgr) - self.model.get_face_landmarks_5(only_center_face=True) - self.model.align_warp_face() - if len(self.model.cropped_faces) == 0: - raise RuntimeError("facexlib align face fail") - align_face = self.model.cropped_faces[0] - align_face_rgb = cv2.cvtColor(align_face, cv2.COLOR_BGR2RGB) - input = npimg2tensor(align_face_rgb) - input = input.to(self.device) - parsing_out = self.model.face_parse( - normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - )[0] - parsing_out = parsing_out.argmax(dim=1, keepdim=True) - bg_label = [0, 16, 18, 7, 8, 9, 14, 15] - bg = sum(parsing_out == i for i in bg_label).bool() - white_image = torch.ones_like(input) - # only keep the face features - face_features_image = torch.where(bg, white_image, to_gray(input)) - if return_tensor: - return face_features_image - else: - return tensor2npimg(face_features_image) - - -@dataclass -class PuLIDProjInput: - id_ante_embedding: torch.Tensor - id_cond_vit: torch.Tensor - id_vit_hidden: List[torch.Tensor] - - -class PreprocessorPuLID(Preprocessor): - """PuLID preprocessor.""" - - def __init__(self): - super().__init__(name="ip-adapter_pulid") - self.tags = ["IP-Adapter"] - self.slider_resolution = PreprocessorParameter(visible=False) - self.returns_image = False - self.preprocessors_deps = [ - "facexlib", - "instant_id_face_embedding", - "EVA02-CLIP-L-14-336", - ] - - def facexlib_detect(self, input_image: np.ndarray) -> torch.Tensor: - facexlib_preprocessor = Preprocessor.get_preprocessor("facexlib") - return facexlib_preprocessor(input_image, return_tensor=True) - - def insightface_antelopev2_detect(self, input_image: np.ndarray) -> torch.Tensor: - antelopev2_preprocessor = Preprocessor.get_preprocessor( - "instant_id_face_embedding" - ) - return antelopev2_preprocessor(input_image) - - def unload(self) -> bool: - unloaded = False - for p_name in self.preprocessors_deps: - p = Preprocessor.get_preprocessor(p_name) - if p is not None: - unloaded = unloaded or p.unload() - return unloaded - - def __call__( - self, - input_image, - resolution, - slider_1=None, - slider_2=None, - slider_3=None, - input_mask=None, - **kwargs - ) -> Preprocessor.Result: - id_ante_embedding = self.insightface_antelopev2_detect(input_image) - if id_ante_embedding.ndim == 1: - id_ante_embedding = id_ante_embedding.unsqueeze(0) - - face_features_image = self.facexlib_detect(input_image) - evaclip_preprocessor = Preprocessor.get_preprocessor("EVA02-CLIP-L-14-336") - assert ( - evaclip_preprocessor is not None - ), "EVA02-CLIP-L-14-336 preprocessor not found! Please install sd-webui-controlnet-evaclip" - r = evaclip_preprocessor(face_features_image) - - return Preprocessor.Result( - value=PuLIDProjInput( - id_ante_embedding=id_ante_embedding, - id_cond_vit=r.id_cond_vit, - id_vit_hidden=r.id_vit_hidden, - ), - display_images=[tensor2npimg(face_features_image)], - ) - - -Preprocessor.add_supported_preprocessor(PreprocessorFaceXLib()) -Preprocessor.add_supported_preprocessor(PreprocessorPuLID()) diff --git a/scripts/supported_preprocessor.py b/scripts/supported_preprocessor.py index 473d6203c..caf5a6a78 100644 --- a/scripts/supported_preprocessor.py +++ b/scripts/supported_preprocessor.py @@ -4,7 +4,7 @@ import numpy as np import torch -from modules import shared, devices +from modules import shared from scripts.logging import logger from scripts.utils import ndarray_lru_cache @@ -101,7 +101,6 @@ class Preprocessor(ABC): use_soft_projection_in_hr_fix = False expand_mask_when_resize_and_fill = False model: Optional[torch.nn.Module] = None - device = devices.get_device_for("controlnet") all_processors: ClassVar[Dict[str, "Preprocessor"]] = {} all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {} @@ -184,19 +183,18 @@ def unload_unused(cls, active_processors: Set["Preprocessor"]): class Result(NamedTuple): value: Any - # The display images shown on UI. - display_images: List[np.ndarray] + # The display image shown on UI. + display_image: np.ndarray + + def get_display_image(self, input_image: np.ndarray, result): + return result if self.returns_image else input_image def cached_call(self, input_image, *args, **kwargs) -> "Preprocessor.Result": """The function exposed that also returns an image for display.""" result = self._cached_call(input_image, *args, **kwargs) - if isinstance(result, Preprocessor.Result): - return result - else: - return Preprocessor.Result( - value=result, - display_images=[result if self.returns_image else input_image], - ) + return Preprocessor.Result( + value=result, display_image=self.get_display_image(input_image, result) + ) @ndarray_lru_cache(max_size=CACHE_SIZE) def _cached_call(self, *args, **kwargs): diff --git a/scripts/utils.py b/scripts/utils.py index e660279a9..c26750f14 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -1,4 +1,3 @@ -from einops import rearrange import torch import os import functools @@ -106,9 +105,8 @@ def wrapper(*args, **kwargs): class TimeMeta(type): - """Metaclass to record execution time on all methods of the - child class.""" - + """ Metaclass to record execution time on all methods of the + child class. """ def __new__(cls, name, bases, attrs): for attr_name, attr_value in attrs.items(): if callable(attr_value): @@ -163,9 +161,7 @@ def read_image(img_path: str) -> str: return encoded_image -def read_image_dir( - img_dir: str, suffixes=(".png", ".jpg", ".jpeg", ".webp") -) -> List[str]: +def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: """Try read all images in given img_dir.""" images = [] for filename in os.listdir(img_dir): @@ -179,7 +175,7 @@ def read_image_dir( def align_dim_latent(x: int) -> int: - """Align the pixel dimension (w/h) to latent dimension. + """ Align the pixel dimension (w/h) to latent dimension. Stable diffusion 1:8 ratio for latent/pixel, i.e., 1 latent unit == 8 pixel unit.""" return (x // 8) * 8 @@ -207,34 +203,9 @@ def resize_image_with_pad(img: np.ndarray, resolution: int): W_target = int(np.round(float(W_raw) * k)) img = cv2.resize(img, (W_target, H_target), interpolation=interpolation) H_pad, W_pad = pad64(H_target), pad64(W_target) - img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode="edge") + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge') def remove_pad(x): return safer_memory(x[:H_target, :W_target]) - return safer_memory(img_padded), remove_pad - - -def npimg2tensor(img: np.ndarray) -> torch.Tensor: - """Convert numpy img ([H, W, C]) to tensor ([1, C, H, W])""" - return rearrange(torch.from_numpy(img).float() / 255.0, "h w c -> 1 c h w") - - -def tensor2npimg(t: torch.Tensor) -> np.ndarray: - """Convert tensor ([1, C, H, W]) to numpy RGB img ([H, W, C])""" - return ( - (rearrange(t, "1 c h w -> h w c") * 255.0) - .to(dtype=torch.uint8) - .cpu() - .numpy() - ) - - -def visualize_inpaint_mask(img): - if img.ndim == 3 and img.shape[2] == 4: - result = img.copy() - mask = result[:, :, 3] - mask = 255 - mask // 2 - result[:, :, 3] = mask - return np.ascontiguousarray(result.copy()) - return img + return safer_memory(img_padded), remove_pad \ No newline at end of file diff --git a/tests/cn_script/batch_hijack_test.py b/tests/cn_script/batch_hijack_test.py index b8c1cc444..0f68fe5bc 100644 --- a/tests/cn_script/batch_hijack_test.py +++ b/tests/cn_script/batch_hijack_test.py @@ -1,4 +1,3 @@ -import numpy as np import unittest.mock import importlib from typing import Any @@ -7,18 +6,13 @@ from modules import processing, scripts, shared -from internal_controlnet.external_code import ControlNetUnit -from scripts import controlnet, batch_hijack +from scripts import controlnet, external_code, batch_hijack batch_hijack.instance.undo_hijack() original_process_images_inner = processing.process_images_inner -def create_unit(**kwargs) -> ControlNetUnit: - return ControlNetUnit(enabled=True, **kwargs) - - class TestBatchHijack(unittest.TestCase): @unittest.mock.patch('modules.script_callbacks.on_script_unloaded') def setUp(self, on_script_unloaded_mock): @@ -64,18 +58,9 @@ def assert_get_cn_batches_works(self, batch_images_list): is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p) batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir) - batch_units = [ - unit - for unit in self.p.script_args - if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH - ] - # Convert iterator to list to avoid double eval of iterator exhausting - # the iterator in following checks. - for unit in batch_units: - unit.batch_images = list(unit.batch_images) - + batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH] if batch_units: - self.assertEqual(min(len(unit.batch_images) for unit in batch_units), len(batches)) + self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches)) else: self.assertEqual(1, len(batches)) @@ -88,15 +73,15 @@ def test_get_cn_batches__empty(self): self.assertEqual(is_batch, False) def test_get_cn_batches__1_simple(self): - self.p.script_args.append(create_unit(image=get_dummy_image())) + self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image())) self.assert_get_cn_batches_works([ - [get_dummy_image()], + [self.p.script_args[0].image], ]) def test_get_cn_batches__2_simples(self): self.p.script_args.extend([ - create_unit(image=get_dummy_image(0)), - create_unit(image=get_dummy_image(1)), + external_code.ControlNetUnit(image=get_dummy_image(0)), + external_code.ControlNetUnit(image=get_dummy_image(1)), ]) self.assert_get_cn_batches_works([ [get_dummy_image(0)], @@ -105,7 +90,7 @@ def test_get_cn_batches__2_simples(self): def test_get_cn_batches__1_batch(self): self.p.script_args.extend([ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -122,14 +107,14 @@ def test_get_cn_batches__1_batch(self): def test_get_cn_batches__2_batches(self): self.p.script_args.extend([ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), get_dummy_image(1), ], ), - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(2), @@ -150,8 +135,8 @@ def test_get_cn_batches__2_batches(self): def test_get_cn_batches__2_mixed(self): self.p.script_args.extend([ - create_unit(image=get_dummy_image(0)), - create_unit( + external_code.ControlNetUnit(image=get_dummy_image(0)), + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(1), @@ -172,8 +157,8 @@ def test_get_cn_batches__2_mixed(self): def test_get_cn_batches__3_mixed(self): self.p.script_args.extend([ - create_unit(image=get_dummy_image(0)), - create_unit( + external_code.ControlNetUnit(image=get_dummy_image(0)), + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(1), @@ -181,7 +166,7 @@ def test_get_cn_batches__3_mixed(self): get_dummy_image(3), ], ), - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(4), @@ -257,14 +242,14 @@ def test_process_images_no_units_forwards(self): def test_process_images__only_simple_units__forwards(self): self.p.script_args = [ - create_unit(image=get_dummy_image()), - create_unit(image=get_dummy_image()), + external_code.ControlNetUnit(image=get_dummy_image()), + external_code.ControlNetUnit(image=get_dummy_image()), ] self.assert_process_images_hijack_called(batch_count=0) def test_process_images__1_batch_1_unit__runs_1_batch(self): self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(), @@ -275,7 +260,7 @@ def test_process_images__1_batch_1_unit__runs_1_batch(self): def test_process_images__2_batches_1_unit__runs_2_batches(self): self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -288,7 +273,7 @@ def test_process_images__2_batches_1_unit__runs_2_batches(self): def test_process_images__8_batches_1_unit__runs_8_batches(self): batch_count = 8 self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(i) for i in range(batch_count)] ), @@ -297,11 +282,11 @@ def test_process_images__8_batches_1_unit__runs_8_batches(self): def test_process_images__1_batch_2_units__runs_1_batch(self): self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(0)] ), - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(1)] ), @@ -310,14 +295,14 @@ def test_process_images__1_batch_2_units__runs_1_batch(self): def test_process_images__2_batches_2_units__runs_2_batches(self): self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), get_dummy_image(1), ], ), - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(2), @@ -329,7 +314,7 @@ def test_process_images__2_batches_2_units__runs_2_batches(self): def test_process_images__3_batches_2_mixed_units__runs_3_batches(self): self.p.script_args = [ - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -337,7 +322,7 @@ def test_process_images__3_batches_2_mixed_units__runs_3_batches(self): get_dummy_image(2), ], ), - create_unit( + controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.SIMPLE, image=get_dummy_image(3), ), diff --git a/tests/cn_script/cn_script_test.py b/tests/cn_script/cn_script_test.py index 1ff7d526d..47a9ae2a2 100644 --- a/tests/cn_script/cn_script_test.py +++ b/tests/cn_script/cn_script_test.py @@ -7,9 +7,8 @@ utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils") -from scripts.enums import ResizeMode +from scripts import external_code from scripts.controlnet import prepare_mask, Script, set_numpy_seed -from internal_controlnet.external_code import ControlNetUnit from modules import processing @@ -118,14 +117,16 @@ class TestScript(unittest.TestCase): "AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg==" ) - sample_np_image = np.zeros(shape=[8, 8, 3], dtype=np.uint8) + sample_np_image = np.array( + [[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8 + ) def test_choose_input_image(self): with self.subTest(name="no image"): with self.assertRaises(ValueError): Script.choose_input_image( p=processing.StableDiffusionProcessing(), - unit=ControlNetUnit(), + unit=external_code.ControlNetUnit(), idx=0, ) @@ -133,30 +134,30 @@ def test_choose_input_image(self): _, resize_mode = Script.choose_input_image( p=MockImg2ImgProcessing( init_images=[TestScript.sample_np_image], - resize_mode=ResizeMode.OUTER_FIT, + resize_mode=external_code.ResizeMode.OUTER_FIT, ), - unit=ControlNetUnit( - image=TestScript.sample_np_image, + unit=external_code.ControlNetUnit( + image=TestScript.sample_base64_image, module="none", - resize_mode=ResizeMode.INNER_FIT, + resize_mode=external_code.ResizeMode.INNER_FIT, ), idx=0, ) - self.assertEqual(resize_mode, ResizeMode.INNER_FIT) + self.assertEqual(resize_mode, external_code.ResizeMode.INNER_FIT) with self.subTest(name="A1111 input"): _, resize_mode = Script.choose_input_image( p=MockImg2ImgProcessing( init_images=[TestScript.sample_np_image], - resize_mode=ResizeMode.OUTER_FIT, + resize_mode=external_code.ResizeMode.OUTER_FIT, ), - unit=ControlNetUnit( + unit=external_code.ControlNetUnit( module="none", - resize_mode=ResizeMode.INNER_FIT, + resize_mode=external_code.ResizeMode.INNER_FIT, ), idx=0, ) - self.assertEqual(resize_mode, ResizeMode.OUTER_FIT) + self.assertEqual(resize_mode, external_code.ResizeMode.OUTER_FIT) if __name__ == "__main__": diff --git a/tests/cn_script/infotext_test.py b/tests/cn_script/infotext_test.py new file mode 100644 index 000000000..61a7002ee --- /dev/null +++ b/tests/cn_script/infotext_test.py @@ -0,0 +1,34 @@ +import unittest +import importlib + +utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils") + +from scripts.infotext import parse_unit +from scripts.external_code import ControlNetUnit + + +class TestInfotext(unittest.TestCase): + def test_parsing(self): + infotext = ( + "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, " + "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, " + "Pixel Perfect: True, Control Mode: Balanced, Hr Option: Both, Save Detected Map: True" + ) + self.assertEqual( + vars( + ControlNetUnit( + module="inpaint_only+lama", + model="control_v11p_sd15_inpaint [ebff9138]", + weight=1, + resize_mode="Resize and Fill", + low_vram=False, + guidance_start=0, + guidance_end=1, + pixel_perfect=True, + control_mode="Balanced", + hr_option="Both", + save_detected_map=True, + ) + ), + vars(parse_unit(infotext)), + ) diff --git a/tests/external_code_api/external_code_test.py b/tests/external_code_api/external_code_test.py index e7c513088..b2b4101d2 100644 --- a/tests/external_code_api/external_code_test.py +++ b/tests/external_code_api/external_code_test.py @@ -9,8 +9,6 @@ from copy import copy from scripts import external_code from scripts import controlnet -from scripts.enums import ResizeMode -from internal_controlnet.external_code import ControlNetUnit from modules import scripts, ui, shared @@ -50,15 +48,79 @@ def test_empty_resizes_min_args(self): def test_empty_resizes_extra_args(self): extra_models = 1 - self.cn_units = [ControlNetUnit()] * (self.max_models + extra_models) + self.cn_units = [external_code.ControlNetUnit()] * (self.max_models + extra_models) self.assert_update_in_place_ok() +class TestControlNetUnitConversion(unittest.TestCase): + def setUp(self): + self.dummy_image = 'base64...' + self.input = {} + self.expected = external_code.ControlNetUnit() + + def assert_converts_to_expected(self): + self.assertEqual(vars(external_code.to_processing_unit(self.input)), vars(self.expected)) + + def test_empty_dict_works(self): + self.assert_converts_to_expected() + + def test_image_works(self): + self.input = { + 'image': self.dummy_image + } + self.expected = external_code.ControlNetUnit(image=self.dummy_image) + self.assert_converts_to_expected() + + def test_image_alias_works(self): + self.input = { + 'input_image': self.dummy_image + } + self.expected = external_code.ControlNetUnit(image=self.dummy_image) + self.assert_converts_to_expected() + + def test_masked_image_works(self): + self.input = { + 'image': self.dummy_image, + 'mask': self.dummy_image, + } + self.expected = external_code.ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image}) + self.assert_converts_to_expected() + + +class TestControlNetUnitImageToDict(unittest.TestCase): + def setUp(self): + self.dummy_image = utils.readImage("test/test_files/img2img_basic.png") + self.input = external_code.ControlNetUnit() + self.expected_image = external_code.to_base64_nparray(self.dummy_image) + self.expected_mask = external_code.to_base64_nparray(self.dummy_image) + + def assert_dict_is_valid(self): + actual_dict = controlnet.image_dict_from_any(self.input.image) + self.assertEqual(actual_dict['image'].tolist(), self.expected_image.tolist()) + self.assertEqual(actual_dict['mask'].tolist(), self.expected_mask.tolist()) + + def test_none(self): + self.assertEqual(controlnet.image_dict_from_any(self.input.image), None) + + def test_image_without_mask(self): + self.input.image = self.dummy_image + self.expected_mask = np.zeros_like(self.expected_image, dtype=np.uint8) + self.assert_dict_is_valid() + + def test_masked_image_tuple(self): + self.input.image = (self.dummy_image, self.dummy_image,) + self.assert_dict_is_valid() + + def test_masked_image_dict(self): + self.input.image = {'image': self.dummy_image, 'mask': self.dummy_image} + self.assert_dict_is_valid() + + class TestPixelPerfectResolution(unittest.TestCase): def test_outer_fit(self): image = np.zeros((100, 100, 3)) target_H, target_W = 50, 100 - resize_mode = ResizeMode.OUTER_FIT + resize_mode = external_code.ResizeMode.OUTER_FIT result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode) expected = 50 # manually computed expected result self.assertEqual(result, expected) @@ -66,11 +128,43 @@ def test_outer_fit(self): def test_inner_fit(self): image = np.zeros((100, 100, 3)) target_H, target_W = 50, 100 - resize_mode = ResizeMode.INNER_FIT + resize_mode = external_code.ResizeMode.INNER_FIT result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode) expected = 100 # manually computed expected result self.assertEqual(result, expected) +class TestGetAllUnitsFrom(unittest.TestCase): + def test_none(self): + self.assertListEqual(external_code.get_all_units_from([None]), []) + + def test_bool(self): + self.assertListEqual(external_code.get_all_units_from([True]), []) + + def test_inheritance(self): + class Foo(external_code.ControlNetUnit): + def __init__(self): + super().__init__(self) + self.bar = 'a' + + foo = Foo() + self.assertListEqual(external_code.get_all_units_from([foo]), [foo]) + + def test_dict(self): + units = external_code.get_all_units_from([{}]) + self.assertGreater(len(units), 0) + self.assertIsInstance(units[0], external_code.ControlNetUnit) + + def test_unitlike(self): + class Foo(object): + """ bar """ + + foo = Foo() + for key in vars(external_code.ControlNetUnit()).keys(): + setattr(foo, key, True) + setattr(foo, 'bar', False) + self.assertListEqual(external_code.get_all_units_from([foo]), [foo]) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/external_code_api/importlib_reload_test.py b/tests/external_code_api/importlib_reload_test.py new file mode 100644 index 000000000..68ba7fb89 --- /dev/null +++ b/tests/external_code_api/importlib_reload_test.py @@ -0,0 +1,24 @@ +import unittest +import importlib +utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils') + + +from scripts import external_code + + +class TestImportlibReload(unittest.TestCase): + def setUp(self): + self.ControlNetUnit = external_code.ControlNetUnit + + def test_reload_does_not_redefine(self): + importlib.reload(external_code) + NewControlNetUnit = external_code.ControlNetUnit + self.assertEqual(self.ControlNetUnit, NewControlNetUnit) + + def test_force_import_does_not_redefine(self): + external_code_copy = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code') + self.assertEqual(self.ControlNetUnit, external_code_copy.ControlNetUnit) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/external_code_api/script_args_test.py b/tests/external_code_api/script_args_test.py new file mode 100644 index 000000000..99c710260 --- /dev/null +++ b/tests/external_code_api/script_args_test.py @@ -0,0 +1,34 @@ +import unittest +import importlib +utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils') + + +from scripts import external_code + + +class TestGetAllUnitsFrom(unittest.TestCase): + def setUp(self): + self.control_unit = { + "module": "none", + "model": utils.get_model("canny"), + "image": utils.readImage("test/test_files/img2img_basic.png"), + "resize_mode": 1, + "low_vram": False, + "processor_res": 64, + "control_mode": external_code.ControlMode.BALANCED.value, + } + self.object_unit = external_code.ControlNetUnit(**self.control_unit) + + def test_empty_converts(self): + script_args = [] + units = external_code.get_all_units_from(script_args) + self.assertListEqual(units, []) + + def test_object_forwards(self): + script_args = [self.object_unit] + units = external_code.get_all_units_from(script_args) + self.assertListEqual(units, [self.object_unit]) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/web_api/detect_test.py b/tests/web_api/detect_test.py index f4021aca3..dc37f33de 100644 --- a/tests/web_api/detect_test.py +++ b/tests/web_api/detect_test.py @@ -75,7 +75,7 @@ def detect_template(payload, output_name: str, status: int = 200): def test_detect_all_modules(module: str): payload = dict( controlnet_input_images=[realistic_girl_face_img], - controlnet_module=module, + controlnet_masks=[mask_img], ) detect_template(payload, f"detect_{module}") @@ -143,16 +143,16 @@ def test_detect_default_param(): dict( controlnet_input_images=[realistic_girl_face_img], controlnet_module="canny", # Canny does not require model download. - controlnet_threshold_a=-100, - controlnet_threshold_b=-100, - controlnet_processor_res=-100, + controlnet_threshold_a=-1, + controlnet_threshold_b=-1, + controlnet_processor_res=-1, ), "default_param", ) assert log_context.is_in_console_logs( [ - "[canny.processor_res] Invalid value(-100), using default value 512.", - "[canny.threshold_a] Invalid value(-100.0), using default value 100.", - "[canny.threshold_b] Invalid value(-100.0), using default value 200.", + "[canny.processor_res] Invalid value(-1), using default value 512.", + "[canny.threshold_a] Invalid value(-1.0), using default value 100.", + "[canny.threshold_b] Invalid value(-1.0), using default value 200.", ] ) diff --git a/tests/web_api/full_coverage/template.py b/tests/web_api/full_coverage/template.py index 333c1d3f7..ecccc74dc 100644 --- a/tests/web_api/full_coverage/template.py +++ b/tests/web_api/full_coverage/template.py @@ -169,16 +169,16 @@ def expect_same_image(img1, img2, diff_img_path: str) -> bool: default_unit = { - "control_mode": "Balanced", + "control_mode": 0, "enabled": True, "guidance_end": 1, "guidance_start": 0, "low_vram": False, "pixel_perfect": True, "processor_res": 512, - "resize_mode": "Crop and Resize", - "threshold_a": -1, - "threshold_b": -1, + "resize_mode": 1, + "threshold_a": 64, + "threshold_b": 64, "weight": 1, } diff --git a/tests/web_api/generation_test.py b/tests/web_api/generation_test.py index 8da3d1c4e..0dd39357d 100644 --- a/tests/web_api/generation_test.py +++ b/tests/web_api/generation_test.py @@ -87,13 +87,12 @@ def test_invalid_param(gen_type, param_name): f"test_invalid_param{(gen_type, param_name)}", gen_type, payload_overrides={}, - unit_overrides={param_name: -100}, + unit_overrides={param_name: -1}, input_image=girl_img, ).exec() - number = "-100" if param_name == "processor_res" else "-100.0" assert log_context.is_in_console_logs( [ - f"[canny.{param_name}] Invalid value({number}), using default value", + f"[canny.{param_name}] Invalid value(-1), using default value", ] ) @@ -172,7 +171,7 @@ def test_reference(): "model": "None", }, input_image=girl_img, - ).exec(result_only=False) + ).exec() def test_advanced_weighting(): @@ -193,7 +192,7 @@ def test_hr_option(): "enable_hr": True, "denoising_strength": 0.75, }, - unit_overrides={"hr_option": "Both"}, + unit_overrides={"hr_option": "HiResFixOption.BOTH"}, input_image=girl_img, ).exec(expected_output_num=3) @@ -204,7 +203,7 @@ def test_hr_option_default(): "test_hr_option_default", "txt2img", payload_overrides={"enable_hr": False}, - unit_overrides={"hr_option": "Both"}, + unit_overrides={"hr_option": "HiResFixOption.BOTH"}, input_image=girl_img, ).exec(expected_output_num=2) diff --git a/tests/web_api/template.py b/tests/web_api/template.py index aa8423dab..2f3a56344 100644 --- a/tests/web_api/template.py +++ b/tests/web_api/template.py @@ -295,15 +295,15 @@ def get_model(model_name: str) -> str: default_unit = { - "control_mode": "Balanced", + "control_mode": 0, "enabled": True, "guidance_end": 1, "guidance_start": 0, "pixel_perfect": True, "processor_res": 512, - "resize_mode": "Crop and Resize", - "threshold_a": -1, - "threshold_b": -1, + "resize_mode": 1, + "threshold_a": 64, + "threshold_b": 64, "weight": 1, "module": "canny", "model": get_model("sd15_canny"), diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/unit_tests/args_test.py b/unit_tests/args_test.py deleted file mode 100644 index de9158220..000000000 --- a/unit_tests/args_test.py +++ /dev/null @@ -1,241 +0,0 @@ -import pytest -import torch -import numpy as np -from dataclasses import dataclass - -from internal_controlnet.args import ControlNetUnit - -H = W = 128 - -img1 = np.ones(shape=[H, W, 3], dtype=np.uint8) -img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2 -mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2 -mask_2d = np.ones(shape=[H, W]) -img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2 -img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2 -ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2 -ui_img = np.ones(shape=[H, W, 4], dtype=np.uint8) -tensor1 = torch.zeros(size=[1, 1], dtype=torch.float16) - - -@pytest.fixture(scope="module") -def set_cls_funcs(): - ControlNetUnit.cls_match_model = lambda s: s in { - "None", - "model1", - "model2", - "control_v11p_sd15_inpaint [ebff9138]", - } - ControlNetUnit.cls_match_module = lambda s: s in { - "none", - "module1", - "inpaint_only+lama", - } - ControlNetUnit.cls_decode_base64 = lambda s: { - "b64img1": img1, - "b64img2": img2, - "b64mask_diff": mask_diff, - }[s] - ControlNetUnit.cls_torch_load_base64 = lambda s: { - "b64tensor1": tensor1, - }[s] - ControlNetUnit.cls_get_preprocessor = lambda s: { - "module1": MockPreprocessor(), - "none": MockPreprocessor(), - "inpaint_only+lama": MockPreprocessor(), - }[s] - - -def test_module_invalid(set_cls_funcs): - with pytest.raises(ValueError) as excinfo: - ControlNetUnit(module="foo") - - assert "module(foo) not found in supported modules." in str(excinfo.value) - - -def test_module_valid(set_cls_funcs): - ControlNetUnit(module="module1") - - -def test_model_invalid(set_cls_funcs): - with pytest.raises(ValueError) as excinfo: - ControlNetUnit(model="foo") - - assert "model(foo) not found in supported models." in str(excinfo.value) - - -def test_model_valid(set_cls_funcs): - ControlNetUnit(model="model1") - - -@pytest.mark.parametrize( - "d", - [ - # API - dict(image={"image": "b64img1"}), - dict(image={"image": "b64img1", "mask": "b64img2"}), - dict(image=["b64img1", "b64img2"]), - dict(image=("b64img1", "b64img2")), - dict(image=[{"image": "b64img1", "mask": "b64img2"}]), - dict(image=[{"image": "b64img1"}]), - dict(image=[{"image": "b64img1", "mask": None}]), - dict( - image=[ - {"image": "b64img1", "mask": "b64img2"}, - {"image": "b64img1", "mask": "b64img2"}, - ] - ), - dict( - image=[ - {"image": "b64img1", "mask": None}, - {"image": "b64img1", "mask": "b64img2"}, - ] - ), - dict( - image=[ - {"image": "b64img1"}, - {"image": "b64img1", "mask": "b64img2"}, - ] - ), - dict(image="b64img1", mask="b64img2"), - dict(image="b64img1"), - dict(image="b64img1", mask_image="b64img2"), - dict(image=None), - # UI - dict(image=dict(image=img1)), - dict(image=dict(image=img1, mask=img2)), - # 2D mask should be accepted. - dict(image=dict(image=img1, mask=mask_2d)), - dict(image=img1, mask=mask_2d), - ], -) -def test_valid_image_formats(set_cls_funcs, d): - ControlNetUnit(**d) - unit = ControlNetUnit.from_dict(d) - unit.get_input_images_rgba() - - -@pytest.mark.parametrize( - "d", - [ - dict(image={"mask": "b64img1"}), - dict(image={"foo": "b64img1", "bar": "b64img2"}), - dict(image=["b64img1"]), - dict(image=("b64img1", "b64img2", "b64img1")), - dict(image=[]), - dict(image=[{"mask": "b64img1"}]), - dict(image=None, mask="b64img2"), - # image & mask have different H x W - dict(image="b64img1", mask="b64mask_diff"), - ], -) -def test_invalid_image_formats(set_cls_funcs, d): - # Setting field will be fine. - ControlNetUnit(**d) - unit = ControlNetUnit.from_dict(d) - # Error on eval. - with pytest.raises((ValueError, AssertionError)): - unit.get_input_images_rgba() - - -def test_mask_alias_conflict(): - with pytest.raises((ValueError, AssertionError)): - ControlNetUnit.from_dict( - dict( - image="b64img1", - mask="b64img1", - mask_image="b64img1", - ) - ), - - -def test_resize_mode(): - ControlNetUnit(resize_mode="Just Resize") - - -def test_weight(): - ControlNetUnit(weight=0.5) - ControlNetUnit(weight=0.0) - with pytest.raises(ValueError): - ControlNetUnit(weight=-1) - with pytest.raises(ValueError): - ControlNetUnit(weight=100) - - -def test_start_end(): - ControlNetUnit(guidance_start=0.0, guidance_end=1.0) - ControlNetUnit(guidance_start=0.5, guidance_end=1.0) - ControlNetUnit(guidance_start=0.5, guidance_end=0.5) - - with pytest.raises(ValueError): - ControlNetUnit(guidance_start=1.0, guidance_end=0.0) - with pytest.raises(ValueError): - ControlNetUnit(guidance_start=11) - with pytest.raises(ValueError): - ControlNetUnit(guidance_end=11) - - -def test_effective_region_mask(): - ControlNetUnit(effective_region_mask="b64img1") - ControlNetUnit(effective_region_mask=None) - ControlNetUnit(effective_region_mask=img1) - - with pytest.raises(ValueError): - ControlNetUnit(effective_region_mask=124) - - -def test_ipadapter_input(): - ControlNetUnit(ipadapter_input=["b64tensor1"]) - ControlNetUnit(ipadapter_input="b64tensor1") - ControlNetUnit(ipadapter_input=None) - - with pytest.raises(ValueError): - ControlNetUnit(ipadapter_input=[]) - - -@dataclass -class MockSlider: - value: float = 1 - minimum: float = 0 - maximum: float = 2 - - -@dataclass -class MockPreprocessor: - slider_resolution = MockSlider() - slider_1 = MockSlider() - slider_2 = MockSlider() - - -def test_preprocessor_sliders(): - unit = ControlNetUnit(enabled=True, module="none") - assert unit.processor_res == 1 - assert unit.threshold_a == 1 - assert unit.threshold_b == 1 - - -def test_preprocessor_sliders_disabled(): - unit = ControlNetUnit(enabled=False, module="none") - assert unit.processor_res == -1 - assert unit.threshold_a == -1 - assert unit.threshold_b == -1 - - -def test_infotext_parsing(): - infotext = ( - "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, " - "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, " - "Pixel Perfect: True, Control Mode: Balanced" - ) - assert ControlNetUnit( - enabled=True, - module="inpaint_only+lama", - model="control_v11p_sd15_inpaint [ebff9138]", - weight=1, - resize_mode="Resize and Fill", - low_vram=False, - guidance_start=0, - guidance_end=1, - pixel_perfect=True, - control_mode="Balanced", - ) == ControlNetUnit.parse(infotext)