diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 29b31e288..9a9ecacaa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -94,6 +94,10 @@ jobs: wait-for-it --service 127.0.0.1:7860 -t 600 python -m pytest -v --junitxml=test/results.xml --cov ./extensions/sd-webui-controlnet --cov-report=xml --verify-base-url ./extensions/sd-webui-controlnet/tests working-directory: stable-diffusion-webui + - name: Run unit tests + run: | + python -m pytest -v ./unit_tests/ + working-directory: stable-diffusion-webui/extensions/sd-webui-controlnet/ - name: Kill test server if: always() run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 diff --git a/README.md b/README.md index c4dda14c2..5a226a508 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ # ControlNet for Stable Diffusion WebUI The WebUI extension for ControlNet and other injection-based SD controls. - -![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/51172d20-606b-4b9f-aba5-db2f2417cb0b) +![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/261f9a50-ba9c-472f-b398-fced61929c4a) This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui), allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly, the merging is not required. # News + +- [2024-05-04] 🔥[v1.1.447] PuLID [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2841] +- [2024-04-30] 🔥[v1.1.446] Effective region mask supported for ControlNet/IPAdapter [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2831] - [2024-04-27] 🔥ControlNet-lllite Normal Dsine released [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2813] - [2024-04-19] 🔥[v1.1.445] IPAdapter advanced weight [Instant Style] [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2770] - [2024-04-17] 🔥[v1.1.444] Marigold depth preprocessor [Discussion thread: https://github.com/Mikubill/sd-webui-controlnet/discussions/2760] diff --git a/internal_controlnet/__init__.py b/internal_controlnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internal_controlnet/args.py b/internal_controlnet/args.py new file mode 100644 index 000000000..50e32802e --- /dev/null +++ b/internal_controlnet/args.py @@ -0,0 +1,443 @@ +from __future__ import annotations +import os +import torch +import numpy as np +from typing import Optional, List, Annotated, ClassVar, Callable, Any, Tuple, Union +from pydantic import BaseModel, validator, root_validator, Field +from PIL import Image +from logging import Logger +from copy import copy +from enum import Enum + +from scripts.enums import ( + InputMode, + ResizeMode, + ControlMode, + HiResFixOption, + PuLIDMode, +) + + +def _unimplemented_func(*args, **kwargs): + raise NotImplementedError("Not implemented.") + + +def field_to_displaytext(fieldname: str) -> str: + return " ".join([word.capitalize() for word in fieldname.split("_")]) + + +def displaytext_to_field(text: str) -> str: + return "_".join([word.lower() for word in text.split(" ")]) + + +def serialize_value(value) -> str: + if isinstance(value, Enum): + return value.value + return str(value) + + +def parse_value(value: str) -> Union[str, float, int, bool]: + if value in ("True", "False"): + return value == "True" + try: + return 
int(value)
+        except ValueError:
+            try:
+                return float(value)
+            except ValueError:
+                return value  # Plain string.
+
+
+class ControlNetUnit(BaseModel):
+    """
+    Represents an entire ControlNet processing unit.
+    """
+
+    class Config:
+        arbitrary_types_allowed = True
+        extra = "ignore"
+
+    cls_match_module: ClassVar[Callable[[str], bool]] = _unimplemented_func
+    cls_match_model: ClassVar[Callable[[str], bool]] = _unimplemented_func
+    cls_decode_base64: ClassVar[Callable[[str], np.ndarray]] = _unimplemented_func
+    cls_torch_load_base64: ClassVar[Callable[[Any], torch.Tensor]] = _unimplemented_func
+    cls_get_preprocessor: ClassVar[Callable[[str], Any]] = _unimplemented_func
+    cls_logger: ClassVar[Logger] = Logger("ControlNetUnit")
+
+    # UI only fields.
+    is_ui: bool = False
+    input_mode: InputMode = InputMode.SIMPLE
+    batch_images: Optional[Any] = None
+    output_dir: str = ""
+    loopback: bool = False
+
+    # General fields.
+    enabled: bool = False
+    module: str = "none"
+
+    @validator("module", always=True, pre=True)
+    def check_module(cls, value: str) -> str:
+        if not ControlNetUnit.cls_match_module(value):
+            raise ValueError(f"module({value}) not found in supported modules.")
+        return value
+
+    model: str = "None"
+
+    @validator("model", always=True, pre=True)
+    def check_model(cls, value: str) -> str:
+        if not ControlNetUnit.cls_match_model(value):
+            raise ValueError(f"model({value}) not found in supported models.")
+        return value
+
+    weight: Annotated[float, Field(ge=0.0, le=2.0)] = 1.0
+
+    # The image to be used for this ControlNetUnit.
+    image: Optional[Any] = None
+
+    resize_mode: ResizeMode = ResizeMode.INNER_FIT
+    low_vram: bool = False
+    processor_res: int = -1
+    threshold_a: float = -1
+    threshold_b: float = -1
+
+    @root_validator
+    def bound_check_params(cls, values: dict) -> dict:
+        """
+        Checks and corrects out-of-range parameters in place.
+        Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
+        their default values if they fall outside the preprocessor's accepted range.
+        """
+        enabled = values.get("enabled")
+        if not enabled:
+            return values
+
+        module = values.get("module")
+        if not module:
+            return values
+
+        preprocessor = cls.cls_get_preprocessor(module)
+        assert preprocessor is not None
+        for unit_param, param in zip(
+            ("processor_res", "threshold_a", "threshold_b"),
+            ("slider_resolution", "slider_1", "slider_2"),
+        ):
+            value = values.get(unit_param)
+            cfg = getattr(preprocessor, param)
+            if value < cfg.minimum or value > cfg.maximum:
+                values[unit_param] = cfg.value
+                # Only report a warning when a non-default value is overridden.
+                if value != -1:
+                    cls.cls_logger.info(
+                        f"[{module}.{unit_param}] Invalid value({value}), using default value {cfg.value}."
+                    )
+        return values
+
+    guidance_start: Annotated[float, Field(ge=0.0, le=1.0)] = 0.0
+    guidance_end: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0
+
+    @root_validator
+    def guidance_check(cls, values: dict) -> dict:
+        start = values.get("guidance_start")
+        end = values.get("guidance_end")
+        if start > end:
+            raise ValueError(f"guidance_start({start}) > guidance_end({end})")
+        return values
+
+    pixel_perfect: bool = False
+    control_mode: ControlMode = ControlMode.BALANCED
+    # Whether to crop the input image based on the A1111 img2img mask. This flag is only used when
+    # `inpaint area` in A1111 is set to `Only masked`. In the API, this corresponds to
+    # `inpaint_full_res = True`.
+    inpaint_crop_input_image: bool = True
+    # If hires fix is enabled in A1111, how should this ControlNet unit be applied.
+    # The value is ignored if the generation is not using hires fix.
+    hr_option: HiResFixOption = HiResFixOption.BOTH
+
+    # Whether to save the detected map of this unit. Setting this option to False prevents
+    # saving the detected map or sending the detected map along with generated images via API.
+    # Currently the option is only accessible in API calls.
+    save_detected_map: bool = True
+
+    # Weight for each layer of ControlNet params.
+    # For ControlNet:
+    # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
+    # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
+    # For T2IAdapter:
+    # - SD1.5: 5 weights (4 encoder block + 1 middle block)
+    # - SDXL: 4 weights (3 encoder block + 1 middle block)
+    # For IPAdapter:
+    # - SD15: 16 weights (6 input blocks + 9 output blocks + 1 middle block)
+    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
+    # Note 1: Setting advanced weighting disables `soft_injection`, i.e. it is
+    # recommended to set ControlMode = BALANCED when using `advanced_weighting`.
+    # Note 2: The field `weight` is still used in some places, e.g. reference_only,
+    # even when advanced_weighting is set.
+    advanced_weighting: Optional[List[float]] = None
+
+    # The effective region mask that the unit's effect should be restricted to.
+    effective_region_mask: Optional[np.ndarray] = None
+
+    @validator("effective_region_mask", pre=True)
+    def parse_effective_region_mask(cls, value) -> np.ndarray:
+        if isinstance(value, str):
+            return cls.cls_decode_base64(value)
+        assert isinstance(value, np.ndarray) or value is None
+        return value
+
+    # The weight mode for PuLID.
+    # https://github.com/ToTheBeginning/PuLID
+    pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
+
+    # ------- API only fields -------
+    # The tensor input for ipadapter. When this field is set in the API, the base64
+    # string will be interpreted by torch.load to reconstruct the ipadapter
+    # preprocessor output.
+    # Currently the option is only accessible in API calls.
+    ipadapter_input: Optional[List[torch.Tensor]] = None
+
+    @validator("ipadapter_input", pre=True)
+    def parse_ipadapter_input(cls, value) -> Optional[List[torch.Tensor]]:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            value = [value]
+        result = [cls.cls_torch_load_base64(b) for b in value]
+        assert result, "input cannot be empty"
+        return result
+
+    # The mask to be used on top of the image.
+    mask: Optional[Any] = None
+
+    @property
+    def accepts_multiple_inputs(self) -> bool:
+        """This unit can accept multiple input images."""
+        return self.module in (
+            "ip-adapter-auto",
+            "ip-adapter_clip_sdxl",
+            "ip-adapter_clip_sdxl_plus_vith",
+            "ip-adapter_clip_sd15",
+            "ip-adapter_face_id",
+            "ip-adapter_face_id_plus",
+            "ip-adapter_pulid",
+            "instant_id_face_embedding",
+        )
+
+    @property
+    def is_animate_diff_batch(self) -> bool:
+        return getattr(self, "animatediff_batch", False)
+
+    @property
+    def uses_clip(self) -> bool:
+        """Whether this unit uses the clip preprocessor."""
+        return any(
+            (
+                ("ip-adapter" in self.module and "face_id" not in self.module),
+                self.module
+                in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"),
+            )
+        )
+
+    @property
+    def is_inpaint(self) -> bool:
+        return "inpaint" in self.module
+
+    def get_actual_preprocessor(self):
+        if self.module == "ip-adapter-auto":
+            return ControlNetUnit.cls_get_preprocessor(
+                self.module
+            ).get_preprocessor_by_model(self.model)
+        return ControlNetUnit.cls_get_preprocessor(self.module)
+
+    @classmethod
+    def parse_image(cls, image) -> np.ndarray:
+        if isinstance(image, np.ndarray):
+            np_image = image
+        elif isinstance(image, str):
+            # Necessary for batch.
+            if os.path.exists(image):
+                np_image = np.array(Image.open(image)).astype("uint8")
+            else:
+                np_image = cls.cls_decode_base64(image)
+        else:
+            raise ValueError(f"Unrecognized image format {image}.")
+
+        # [H, W] => [H, W, 3]
+        if np_image.ndim == 2:
+            np_image = np.stack([np_image, np_image, np_image], axis=-1)
+        assert np_image.ndim == 3
+        assert np_image.shape[2] == 3
+        return np_image
+
+    @classmethod
+    def combine_image_and_mask(
+        cls, np_image: np.ndarray, np_mask: Optional[np.ndarray] = None
+    ) -> np.ndarray:
+        """RGB + Alpha(Optional) => RGBA"""
+        # TODO: Change protocol to use 255 as A channel value.
+        # Note: mask is by default zeros, as both the inpaint and
+        # clip masks do extra work on the masked area.
+        np_mask = (np.zeros_like(np_image) if np_mask is None else np_mask)[:, :, 0:1]
+        if np_image.shape[:2] != np_mask.shape[:2]:
+            raise ValueError(
+                f"image shape ({np_image.shape[:2]}) not aligned with mask shape ({np_mask.shape[:2]})"
+            )
+        return np.concatenate([np_image, np_mask], axis=2)  # [H, W, 4]
+
+    @classmethod
+    def legacy_field_alias(cls, values: dict) -> dict:
+        ext_compat_keys = {
+            "guidance": "guidance_end",
+            "lowvram": "low_vram",
+            "input_image": "image",
+        }
+        for alias, key in ext_compat_keys.items():
+            if alias in values:
+                assert key not in values, f"Conflict of field '{alias}' and '{key}'"
+                values[key] = values[alias]
+                cls.cls_logger.warn(
+                    f"Deprecated alias '{alias}' detected. This field will be removed on 2024-06-01. "
+                    f"Please use '{key}' instead."
+                )
+
+        return values
+
+    @classmethod
+    def mask_alias(cls, values: dict) -> dict:
+        """
+        Field "mask_image" is an alias of field "mask".
+        This is for compatibility with the SD Forge API.
+        """
+        mask_image = values.get("mask_image")
+        mask = values.get("mask")
+        if mask_image is not None:
+            if mask is not None:
+                raise ValueError("Cannot specify both 'mask' and 'mask_image'!")
+            values["mask"] = mask_image
+        return values
+
+    def get_input_images_rgba(self) -> Optional[List[np.ndarray]]:
+        """
+        RGBA images with potentially different sizes.
+        We cannot return a single [B, H, W, C=4] array here because calculating
+        the final resolution requires the generation target's dimensions.
+
+        Parses the image field in the following formats.
+ API + - image = {"image": base64image, "mask": base64image,} + - image = [image, mask] + - image = (image, mask) + - image = [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...] + - image = base64image, mask = base64image + + UI: + - image = {"image": np_image, "mask": np_image,} + - image = np_image, mask = np_image + """ + init_image = self.image + init_mask = self.mask + + if init_image is None: + assert init_mask is None + return None + + if isinstance(init_image, (list, tuple)): + if not init_image: + raise ValueError(f"{init_image} is not a valid 'image' field value") + if isinstance(init_image[0], dict): + # [{"image": ..., "mask": ...}, {"image": ..., "mask": ...}, ...] + images = init_image + else: + assert len(init_image) == 2 + # [image, mask] + # (image, mask) + images = [ + { + "image": init_image[0], + "mask": init_image[1], + } + ] + elif isinstance(init_image, dict): + # {"image": ..., "mask": ...} + images = [init_image] + elif isinstance(init_image, (str, np.ndarray)): + # image = base64image, mask = base64image + images = [ + { + "image": init_image, + "mask": init_mask, + } + ] + else: + raise ValueError(f"Unrecognized image field {init_image}") + + np_images = [] + for image_dict in images: + assert isinstance(image_dict, dict) + image = image_dict.get("image") + mask = image_dict.get("mask") + assert image is not None + + np_image = self.parse_image(image) + np_mask = self.parse_image(mask) if mask is not None else None + np_images.append(self.combine_image_and_mask(np_image, np_mask)) # [H, W, 4] + + return np_images + + @classmethod + def from_dict(cls, values: dict) -> ControlNetUnit: + values = copy(values) + values = cls.legacy_field_alias(values) + values = cls.mask_alias(values) + return ControlNetUnit(**values) + + @classmethod + def from_infotext_args(cls, *args) -> ControlNetUnit: + assert len(args) == len(ControlNetUnit.infotext_fields()) + return cls.from_dict( + {k: v for k, v in zip(ControlNetUnit.infotext_fields(), args)} + ) + + @staticmethod + def infotext_fields() -> Tuple[str]: + """Fields that should be included in infotext. + You should define a Gradio element with exact same name in ControlNetUiGroup + as well, so that infotext can wire the value to correct field when pasting + infotext. 
+ """ + return ( + "module", + "model", + "weight", + "resize_mode", + "processor_res", + "threshold_a", + "threshold_b", + "guidance_start", + "guidance_end", + "pixel_perfect", + "control_mode", + ) + + def serialize(self) -> str: + """Serialize the unit for infotext.""" + infotext_dict = { + field_to_displaytext(field): serialize_value(getattr(self, field)) + for field in ControlNetUnit.infotext_fields() + } + if not all( + "," not in str(v) and ":" not in str(v) for v in infotext_dict.values() + ): + self.cls_logger.error(f"Unexpected tokens encountered:\n{infotext_dict}") + return "" + + return ", ".join(f"{field}: {value}" for field, value in infotext_dict.items()) + + @classmethod + def parse(cls, text: str) -> ControlNetUnit: + return ControlNetUnit( + enabled=True, + **{ + displaytext_to_field(key): parse_value(value) + for item in text.split(",") + for (key, value) in (item.strip().split(": "),) + }, + ) diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py index 016142e5a..157f08efc 100644 --- a/internal_controlnet/external_code.py +++ b/internal_controlnet/external_code.py @@ -1,58 +1,30 @@ -import base64 -import io -from dataclasses import dataclass -from enum import Enum from copy import copy from typing import List, Any, Optional, Union, Tuple, Dict -import torch import numpy as np from modules import scripts, processing, shared -from modules.safe import unsafe_torch_load +from modules.api import api +from .args import ControlNetUnit from scripts import global_state from scripts.logging import logger -from scripts.enums import HiResFixOption -from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter +from scripts.enums import ( + ResizeMode, + BatchOption, # noqa: F401 + ControlMode, # noqa: F401 +) +from scripts.supported_preprocessor import ( + Preprocessor, + PreprocessorParameter, # noqa: F401 +) -from modules.api import api +import torch +import base64 +import io +from modules.safe import unsafe_torch_load def get_api_version() -> int: - return 2 - - -class ControlMode(Enum): - """ - The improved guess mode. - """ - - BALANCED = "Balanced" - PROMPT = "My prompt is more important" - CONTROL = "ControlNet is more important" - - -class BatchOption(Enum): - DEFAULT = "All ControlNet units for all images in a batch" - SEPARATE = "Each ControlNet unit for each image in a batch" - - -class ResizeMode(Enum): - """ - Resize modes for ControlNet input images. - """ - - RESIZE = "Just Resize" - INNER_FIT = "Crop and Resize" - OUTER_FIT = "Resize and Fill" - - def int_value(self): - if self == ResizeMode.RESIZE: - return 0 - elif self == ResizeMode.INNER_FIT: - return 1 - elif self == ResizeMode.OUTER_FIT: - return 2 - assert False, "NOTREACHED" + return 3 resize_mode_aliases = { @@ -82,15 +54,6 @@ def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode: return value -def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode: - if isinstance(value, str): - return ControlMode(value) - elif isinstance(value, int): - return [e for e in ControlMode][value] - else: - return value - - def visualize_inpaint_mask(img): if img.ndim == 3 and img.shape[2] == 4: result = img.copy() @@ -152,146 +115,7 @@ def pixel_perfect_resolution( return int(np.round(estimation)) -InputImage = Union[np.ndarray, str] -InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] - - -@dataclass -class ControlNetUnit: - """ - Represents an entire ControlNet processing unit. 
- """ - - enabled: bool = True - module: str = "none" - model: str = "None" - weight: float = 1.0 - image: Optional[Union[InputImage, List[InputImage]]] = None - resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT - low_vram: bool = False - processor_res: int = -1 - threshold_a: float = -1 - threshold_b: float = -1 - guidance_start: float = 0.0 - guidance_end: float = 1.0 - pixel_perfect: bool = False - control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED - # Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area` - # in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`. - inpaint_crop_input_image: bool = True - # If hires fix is enabled in A1111, how should this ControlNet unit be applied. - # The value is ignored if the generation is not using hires fix. - hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH - - # Whether save the detected map of this unit. Setting this option to False prevents saving the - # detected map or sending detected map along with generated images via API. - # Currently the option is only accessible in API calls. - save_detected_map: bool = True - - # Weight for each layer of ControlNet params. - # For ControlNet: - # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block) - # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block) - # For T2IAdapter - # - SD1.5: 5 weights (4 encoder block + 1 middle block) - # - SDXL: 4 weights (3 encoder block + 1 middle block) - # For IPAdapter - # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block) - # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block) - # Note1: Setting advanced weighting will disable `soft_injection`, i.e. - # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`. - # Note2: The field `weight` is still used in some places, e.g. reference_only, - # even advanced_weighting is set. - advanced_weighting: Optional[List[float]] = None - - # The effective region mask that unit's effect should be restricted to. - effective_region_mask: Optional[np.ndarray] = None - - # The tensor input for ipadapter. When this field is set in the API, - # the base64string will be interpret by torch.load to reconstruct ipadapter - # preprocessor output. - # Currently the option is only accessible in API calls. - ipadapter_input: Optional[List[Any]] = None - - def __eq__(self, other): - if not isinstance(other, ControlNetUnit): - return False - - return vars(self) == vars(other) - - def accepts_multiple_inputs(self) -> bool: - """This unit can accept multiple input images.""" - return self.module in ( - "ip-adapter_clip_sdxl", - "ip-adapter_clip_sdxl_plus_vith", - "ip-adapter_clip_sd15", - "ip-adapter_face_id", - "ip-adapter_face_id_plus", - "instant_id_face_embedding", - ) - - @staticmethod - def infotext_excluded_fields() -> List[str]: - return [ - "image", - "enabled", - # API-only fields. - "advanced_weighting", - "ipadapter_input", - # End of API-only fields. - # Note: "inpaint_crop_image" is img2img inpaint only flag, which does not - # provide much information when restoring the unit. 
- "inpaint_crop_input_image", - "effective_region_mask", - ] - - @property - def is_animate_diff_batch(self) -> bool: - return getattr(self, "animatediff_batch", False) - - @property - def uses_clip(self) -> bool: - """Whether this unit uses clip preprocessor.""" - return any( - ( - ("ip-adapter" in self.module and "face_id" not in self.module), - self.module - in ("clip_vision", "revision_clipvision", "revision_ignore_prompt"), - ) - ) - - @property - def is_inpaint(self) -> bool: - return "inpaint" in self.module - - def bound_check_params(self) -> None: - """ - Checks and corrects negative parameters in ControlNetUnit 'unit' in place. - Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to - their default values if negative. - """ - preprocessor = Preprocessor.get_preprocessor(self.module) - for unit_param, param in zip( - ("processor_res", "threshold_a", "threshold_b"), - ("slider_resolution", "slider_1", "slider_2"), - ): - value = getattr(self, unit_param) - cfg: PreprocessorParameter = getattr(preprocessor, param) - if value < 0: - setattr(self, unit_param, cfg.value) - logger.info( - f"[{self.module}.{unit_param}] Invalid value({value}), using default value {cfg.value}." - ) - - def get_actual_preprocessor(self) -> Preprocessor: - if self.module == "ip-adapter-auto": - return Preprocessor.get_preprocessor(self.module).get_preprocessor_by_model( - self.model - ) - return Preprocessor.get_preprocessor(self.module) - - -def to_base64_nparray(encoding: str): +def to_base64_nparray(encoding: str) -> np.ndarray: """ Convert a base64 image into the image type the extension uses """ @@ -396,73 +220,14 @@ def get_max_models_num(): return max_models_num -def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit: +def to_processing_unit(unit: Union[Dict, ControlNetUnit]) -> ControlNetUnit: """ Convert different types to processing unit. - If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details. """ - - ext_compat_keys = { - "guessmode": "guess_mode", - "guidance": "guidance_end", - "lowvram": "low_vram", - "input_image": "image", - } - if isinstance(unit, dict): - unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} - - # Handle mask - mask = None - if "mask" in unit: - mask = unit["mask"] - del unit["mask"] - - if "mask_image" in unit: - mask = unit["mask_image"] - del unit["mask_image"] - - if "image" in unit and not isinstance(unit["image"], dict): - unit["image"] = ( - {"image": unit["image"], "mask": mask} - if mask is not None - else unit["image"] if unit["image"] else None - ) - - # Parse ipadapter_input - if "ipadapter_input" in unit and unit["ipadapter_input"] is not None: - - def decode_base64(b: str) -> torch.Tensor: - decoded_bytes = base64.b64decode(b) - return unsafe_torch_load(io.BytesIO(decoded_bytes)) - - if isinstance(unit["ipadapter_input"], str): - unit["ipadapter_input"] = [unit["ipadapter_input"]] + return ControlNetUnit.from_dict(unit) - unit["ipadapter_input"] = [ - decode_base64(b) for b in unit["ipadapter_input"] - ] - - if unit.get("effective_region_mask", None) is not None: - base64img = unit["effective_region_mask"] - assert isinstance(base64img, str) - unit["effective_region_mask"] = to_base64_nparray(base64img) - - if "guess_mode" in unit: - logger.warning( - "Guess Mode is removed since 1.1.136. Please use Control Mode instead." 
- ) - - for k in unit.keys(): - if k not in vars(ControlNetUnit): - logger.warn(f"Received unrecognized key '{k}' in API.") - - unit = ControlNetUnit( - **{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()} - ) - - # temporary, check #602 - # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' + assert isinstance(unit, ControlNetUnit) return unit @@ -649,3 +414,23 @@ def is_cn_script(script: scripts.Script) -> bool: """ return script.title().lower() == "controlnet" + + +# TODO: Add model constraint +ControlNetUnit.cls_match_model = lambda model: True +ControlNetUnit.cls_match_module = ( + lambda module: Preprocessor.get_preprocessor(module) is not None +) +ControlNetUnit.cls_get_preprocessor = Preprocessor.get_preprocessor +ControlNetUnit.cls_decode_base64 = to_base64_nparray + + +def decode_base64(b: str) -> torch.Tensor: + decoded_bytes = base64.b64decode(b) + return unsafe_torch_load(io.BytesIO(decoded_bytes)) + + +ControlNetUnit.cls_torch_load_base64 = decode_base64 +ControlNetUnit.cls_logger = logger + +logger.debug("ControlNetUnit initialized") diff --git a/javascript/controlnet_unit.mjs b/javascript/controlnet_unit.mjs index 53668bb12..147ecebaa 100644 --- a/javascript/controlnet_unit.mjs +++ b/javascript/controlnet_unit.mjs @@ -75,7 +75,6 @@ export class ControlNetUnit { this.attachImageUploadListener(); this.attachImageStateChangeObserver(); this.attachA1111SendInfoObserver(); - this.attachPresetDropdownObserver(); } getTabNavButton() { @@ -269,24 +268,4 @@ export class ControlNetUnit { }); } } - - attachPresetDropdownObserver() { - const presetDropDown = this.tab.querySelector('.cnet-preset-dropdown'); - - new MutationObserver((mutationsList) => { - for (const mutation of mutationsList) { - if (mutation.removedNodes.length > 0) { - setTimeout(() => { - this.updateActiveState(); - this.updateActiveUnitCount(); - this.updateActiveControlType(); - }, 1000); - return; - } - } - }).observe(presetDropDown, { - childList: true, - subtree: true, - }); - } } diff --git a/requirements.txt b/requirements.txt index 5fb5e2e2b..f0072d277 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ addict yapf albumentations==1.4.3 matplotlib +facexlib diff --git a/scripts/api.py b/scripts/api.py index c2d348a3a..fb6df759e 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -137,12 +137,12 @@ async def detect( ) unit = ControlNetUnit( + enabled=True, module=preprocessor.label, processor_res=controlnet_processor_res, threshold_a=controlnet_threshold_a, threshold_b=controlnet_threshold_b, ) - unit.bound_check_params() tensors = [] images = [] @@ -179,7 +179,7 @@ def accept(self, json_dict: dict) -> None: low_vram=low_vram, ) if preprocessor.returns_image: - images.append(encode_to_base64(result.display_image)) + images.append(encode_to_base64(result.display_images[0])) else: tensors.append(encode_tensor_to_base64(result.value)) diff --git a/scripts/batch_hijack.py b/scripts/batch_hijack.py index fe001e610..8e72c9e3b 100644 --- a/scripts/batch_hijack.py +++ b/scripts/batch_hijack.py @@ -1,10 +1,10 @@ import os -from copy import copy from typing import Tuple, List from modules import img2img, processing, shared, script_callbacks from scripts import external_code from scripts.enums import InputMode +from scripts.logging import logger class BatchHijack: def __init__(self): @@ -194,7 +194,7 @@ def unhijack_function(module, name, new_name): def get_cn_batches(p: 
processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]: units = external_code.get_all_units_in_processing(p) - units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)] + units = [unit.copy() for unit in units if getattr(unit, 'enabled', False)] any_unit_is_batch = False output_dir = '' input_file_names = [] @@ -222,6 +222,8 @@ def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[ else: batches[i].append(unit.image) + if any_unit_is_batch: + logger.info(f"Batch enabled ({len(batches)})") return any_unit_is_batch, batches, output_dir, input_file_names diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 3845d537e..35b349ff2 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -4,10 +4,9 @@ import logging from collections import OrderedDict from copy import copy, deepcopy -from typing import Dict, Optional, Tuple, List, Union +from typing import Dict, Optional, Tuple, List import modules.scripts as scripts from modules import shared, devices, script_callbacks, processing, masking, images -from modules.api.api import decode_base64_to_image import gradio as gr import time @@ -16,14 +15,25 @@ # Register all preprocessors. import scripts.preprocessor as preprocessor_init # noqa from annotator.util import HWC3 +from internal_controlnet.external_code import ControlNetUnit from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lllite import clear_all_lllite from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter +from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent from scripts.hook import ControlParams, UnetHook, HackedImageRNG -from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption -from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit +from scripts.enums import ( + ControlModelType, + InputMode, + StableDiffusionVersion, + HiResFixOption, + PuLIDMode, + ControlMode, + BatchOption, + ResizeMode, +) +from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup from scripts.controlnet_ui.photopea import Photopea from scripts.logging import logger from scripts.supported_preprocessor import Preprocessor @@ -89,44 +99,7 @@ def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img): global_state.update_cn_models() - - -def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: - if image is None: - return None - - if isinstance(image, (tuple, list)): - image = {'image': image[0], 'mask': image[1]} - elif not isinstance(image, dict): - image = {'image': image, 'mask': None} - else: # type(image) is dict - # copy to enable modifying the dict and prevent response serialization error - image = dict(image) - - if isinstance(image['image'], str): - if os.path.exists(image['image']): - image['image'] = np.array(Image.open(image['image'])).astype('uint8') - elif image['image']: - image['image'] = external_code.to_base64_nparray(image['image']) - else: - image['image'] = None - - # If there is no image, return image with None image and None mask - if image['image'] is None: - image['mask'] = None - return image - - if 'mask' not in image or image['mask'] is None: - image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) - elif 
isinstance(image['mask'], str): - if os.path.exists(image['mask']): - image['mask'] = np.array(Image.open(image['mask']).convert("RGB")).astype('uint8') - elif image['mask']: - image['mask'] = external_code.to_base64_nparray(image['mask']) - else: - image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) - - return image +logger.info(f"ControlNet {controlnet_version.version_flag}") def prepare_mask( @@ -219,7 +192,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor: def get_control( p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, idx: int, control_model_type: ControlModelType, preprocessor: Preprocessor, @@ -232,12 +205,12 @@ def get_control( h, w, hr_y, hr_x = Script.get_target_dimensions(p) input_image, resize_mode = Script.choose_input_image(p, unit, idx) if isinstance(input_image, list): - assert unit.accepts_multiple_inputs() or unit.is_animate_diff_batch + assert unit.accepts_multiple_inputs or unit.is_animate_diff_batch input_images = input_image else: # Following operations are only for single input image. input_image = Script.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode) input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy - if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT: + if unit.module == 'inpaint_only+lama' and resize_mode == ResizeMode.OUTER_FIT: # inpaint_only+lama is special and required outpaint fix _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x) input_images = [input_image] @@ -279,6 +252,7 @@ def preprocess_input_image(input_image: np.ndarray): ) detected_map = result.value is_image = preprocessor.returns_image + # TODO: Refactor img control detection logic. 
if high_res_fix: if is_image: hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) @@ -293,7 +267,8 @@ def preprocess_input_image(input_image: np.ndarray): store_detected_map(detected_map, unit.module) else: control = detected_map - store_detected_map(input_image, unit.module) + for image in result.display_images: + store_detected_map(image, unit.module) if control_model_type == ControlModelType.T2I_StyleAdapter: control = control['last_hidden_state'] @@ -327,11 +302,11 @@ def __init__(self) -> None: self.latest_network = None self.input_image = None self.latest_model_hash = "" - self.enabled_units: List[external_code.ControlNetUnit] = [] + self.enabled_units: List[ControlNetUnit] = [] self.detected_map = [] self.post_processors = [] self.noise_modifier = None - self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False] + self.ui_batch_option_state = [BatchOption.DEFAULT.value, False] batch_hijack.instance.process_batch_callbacks.append(self.batch_tab_process) batch_hijack.instance.process_batch_each_callbacks.append(self.batch_tab_process_each) batch_hijack.instance.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each) @@ -343,27 +318,14 @@ def title(self): def show(self, is_img2img): return scripts.AlwaysVisible - @staticmethod - def get_default_ui_unit(is_ui=True): - cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit - return cls( - enabled=False, - module="none", - model="None" - ) - def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]: - group = ControlNetUiGroup( - is_img2img, - Script.get_default_ui_unit(), - photopea, - ) + group = ControlNetUiGroup(is_img2img, photopea) return group, group.render(tabname, elem_id_tabname) def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str): batch_option = gr.Radio( - choices=[e.value for e in external_code.BatchOption], - value=external_code.BatchOption.DEFAULT.value, + choices=[e.value for e in BatchOption], + value=BatchOption.DEFAULT.value, label="Batch Option", elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio", elem_classes="controlnet_batch_option_radio", @@ -516,7 +478,7 @@ def get_element(obj, strict=False): return attribute_value if attribute_value is not None else default @staticmethod - def parse_remote_call(p, unit: external_code.ControlNetUnit, idx): + def parse_remote_call(p, unit: ControlNetUnit, idx): selector = Script.get_remote_call unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True) @@ -612,7 +574,7 @@ def high_quality_resize(x, size): return y - if resize_mode == external_code.ResizeMode.RESIZE: + if resize_mode == ResizeMode.RESIZE: detected_map = high_quality_resize(detected_map, (w, h)) detected_map = safe_numpy(detected_map) return get_pytorch_control(detected_map), detected_map @@ -625,7 +587,7 @@ def high_quality_resize(x, size): safeint = lambda x: int(np.round(x)) - if resize_mode == external_code.ResizeMode.OUTER_FIT: + if resize_mode == ResizeMode.OUTER_FIT: k = min(k0, k1) borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0) high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype) @@ -653,10 +615,31 @@ def high_quality_resize(x, size): @staticmethod def get_enabled_units(p): + def unfold_merged(unit: ControlNetUnit) -> List[ControlNetUnit]: + """Unfolds a merged unit to 
multiple units. Keeps the unit merged for + preprocessors that can accept multiple input images. + """ + if unit.input_mode != InputMode.MERGE: + return [unit] + + if unit.accepts_multiple_inputs: + unit.input_mode = InputMode.SIMPLE + return [unit] + + assert isinstance(unit.image, list) + result = [] + for image in unit.image: + u = unit.copy() + u.image = [image] + u.input_mode = InputMode.SIMPLE + u.weight = unit.weight / len(unit.image) + result.append(u) + return result + units = external_code.get_all_units_in_processing(p) if len(units) == 0: # fill a null group - remote_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0) + remote_unit = Script.parse_remote_call(p, ControlNetUnit(), 0) if remote_unit.enabled: units.append(remote_unit) @@ -665,11 +648,7 @@ def get_enabled_units(p): local_unit = Script.parse_remote_call(p, unit, idx) if not local_unit.enabled: continue - - if hasattr(local_unit, "unfold_merged"): - enabled_units.extend(local_unit.unfold_merged()) - else: - enabled_units.append(copy(local_unit)) + enabled_units.extend(unfold_merged(local_unit)) Infotext.write_infotext(enabled_units, p) return enabled_units @@ -677,49 +656,37 @@ def get_enabled_units(p): @staticmethod def choose_input_image( p: processing.StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, idx: int - ) -> Tuple[np.ndarray, external_code.ResizeMode]: + ) -> Tuple[np.ndarray, ResizeMode]: """ Choose input image from following sources with descending priority: - p.image_control: [Deprecated] Lagacy way to pass image to controlnet. - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet. - - unit.image: ControlNet tab input image. - - p.init_images: A1111 img2img tab input image. + - unit.image: ControlNet unit input image. + - p.init_images: A1111 img2img input image. Returns: - The input image in ndarray form. - The resize mode. """ - def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: - unit_has_multiple_images = ( - isinstance(unit.image, list) and - len(unit.image) > 0 and - "image" in unit.image[0] - ) - if unit_has_multiple_images: - return [ - d - for img in unit.image - for d in (image_dict_from_any(img),) - if d is not None - ] - return image_dict_from_any(unit.image) - - def decode_image(img) -> np.ndarray: - """Need to check the image for API compatibility.""" - if isinstance(img, str): - return np.asarray(decode_base64_to_image(image['image'])) - else: - assert isinstance(img, np.ndarray) - return img + def from_rgba_to_input(img: np.ndarray) -> np.ndarray: + if ( + shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) or + (img[:, :, 3] <= 5).all() or + (img[:, :, 3] >= 250).all() + ): + # Take RGB + return img[:, :, :3] + logger.info("Canvas scribble mode. Using mask scribble as input.") + return HWC3(img[:, :, 3]) # 4 input image sources. 
p_image_control = getattr(p, "image_control", None) p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx) - image = parse_unit_image(unit) + image = unit.get_input_images_rgba() a1111_image = getattr(p, "init_images", [None])[0] - resize_mode = external_code.resize_mode_from_value(unit.resize_mode) + resize_mode = unit.resize_mode if batch_hijack.instance.is_batch and p_image_control is not None: logger.warning("Warn: Using legacy field 'p.image_control'.") @@ -732,42 +699,18 @@ def decode_image(img) -> np.ndarray: input_image = np.concatenate([color, alpha], axis=2) else: input_image = HWC3(np.asarray(p_input_image)) - elif image: - if isinstance(image, list): - # Add mask logic if later there is a processor that accepts mask - # on multiple inputs. - input_image = [HWC3(decode_image(img['image'])) for img in image] - if unit.is_animate_diff_batch and len(image) > 0 and 'mask' in image[0] and image[0]['mask'] is not None: - for idx in range(len(input_image)): - while len(image[idx]['mask'].shape) < 3: - image[idx]['mask'] = image[idx]['mask'][..., np.newaxis] - if unit.is_inpaint or unit.uses_clip: - color = HWC3(image[idx]["image"]) - alpha = image[idx]['mask'][:, :, 0:1] - input_image[idx] = np.concatenate([color, alpha], axis=2) + elif image is not None: + assert isinstance(image, list) + # Inpaint mask or CLIP mask. + if unit.is_inpaint or unit.uses_clip: + # RGBA + input_image = image else: - input_image = HWC3(decode_image(image['image'])) - if 'mask' in image and image['mask'] is not None: - while len(image['mask'].shape) < 3: - image['mask'] = image['mask'][..., np.newaxis] - if unit.is_inpaint or unit.uses_clip: - logger.info("using mask") - color = HWC3(image['image']) - alpha = image['mask'][:, :, 0:1] - input_image = np.concatenate([color, alpha], axis=2) - elif ( - not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and - # There is wield gradio issue that would produce mask that is - # not pure color when no scribble is made on canvas. - # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638. - not ( - (image['mask'][:, :, 0] <= 5).all() or - (image['mask'][:, :, 0] >= 250).all() - ) - ): - logger.info("using mask as input") - input_image = HWC3(image['mask'][:, :, 0]) - unit.module = 'none' # Always use black bg and white line + # RGB + input_image = [from_rgba_to_input(img) for img in image] + + if len(input_image) == 1: + input_image = input_image[0] elif a1111_image is not None: input_image = HWC3(np.asarray(a1111_image)) a1111_i2i_resize_mode = getattr(p, "resize_mode", None) @@ -799,9 +742,9 @@ def decode_image(img) -> np.ndarray: @staticmethod def try_crop_image_with_a1111_mask( p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, input_image: np.ndarray, - resize_mode: external_code.ResizeMode, + resize_mode: ResizeMode, ) -> np.ndarray: """ Crop ControlNet input image based on A1111 inpaint mask given. 
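
The refactored input path above standardizes every unit image as an RGBA array: `combine_image_and_mask` packs the RGB image with an optional single-channel mask into [H, W, 4], and `from_rgba_to_input` later inspects the alpha channel to decide whether it carries a real hand-drawn scribble or just canvas noise (see https://github.com/Mikubill/sd-webui-controlnet/issues/1638). Below is a minimal sketch of that convention — the helper names `pack_rgba`/`unpack_input` are hypothetical, with the 5/250 alpha thresholds taken from `from_rgba_to_input` in this diff:

```python
from typing import Optional

import numpy as np


def pack_rgba(image: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
    # Mirrors combine_image_and_mask: RGB [H, W, 3] + optional mask -> RGBA [H, W, 4].
    alpha = (np.zeros_like(image) if mask is None else mask)[:, :, 0:1]
    return np.concatenate([image, alpha], axis=2)


def unpack_input(rgba: np.ndarray) -> np.ndarray:
    # Mirrors from_rgba_to_input: a near-uniform alpha channel means no scribble was drawn.
    alpha = rgba[:, :, 3]
    if (alpha <= 5).all() or (alpha >= 250).all():
        return rgba[:, :, :3]  # Plain RGB input.
    # Otherwise the alpha channel is treated as a scribble; expand it to 3 channels (like HWC3).
    return np.stack([alpha] * 3, axis=-1)


rgb = np.zeros((64, 64, 3), dtype=np.uint8)
scribble = np.zeros((64, 64, 3), dtype=np.uint8)
scribble[16:48, 16:48] = 255

assert unpack_input(pack_rgba(rgb)).shape == (64, 64, 3)              # RGB path
assert (unpack_input(pack_rgba(rgb, scribble))[32, 32] == 255).all()  # scribble path
```
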
@@ -843,7 +786,7 @@ def try_crop_image_with_a1111_mask( input_image = [x.crop(crop_region) for x in input_image] input_image = [ - images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) + images.resize_image(ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) for x in input_image ] @@ -852,7 +795,7 @@ def try_crop_image_with_a1111_mask( return input_image @staticmethod - def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: + def check_sd_version_compatible(unit: ControlNetUnit) -> None: """ Checks whether the given ControlNet unit has model compatible with the currently active sd model. An exception is thrown if ControlNet unit is detected to be @@ -918,7 +861,7 @@ def controlnet_main_entry(self, p): if not batch_hijack.instance.is_batch: self.enabled_units = Script.get_enabled_units(p) - batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value + batch_option_uint_separate = self.ui_batch_option_state[0] == BatchOption.SEPARATE.value batch_option_style_align = self.ui_batch_option_state[1] if len(self.enabled_units) == 0 and not batch_option_style_align: @@ -945,7 +888,6 @@ def controlnet_main_entry(self, p): high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) for idx, unit in enumerate(self.enabled_units): - unit.bound_check_params() Script.check_sd_version_compatible(unit) if ( 'inpaint_only' == unit.module and @@ -999,7 +941,7 @@ def controlnet_main_entry(self, p): elif unit.is_animate_diff_batch or control_model_type in [ControlModelType.SparseCtrl]: cn_ad_keyframe_idx = getattr(unit, "batch_keyframe_idx", None) def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe_idx): - if unit.accepts_multiple_inputs(): + if unit.accepts_multiple_inputs: ip_adapter_image_emb_cond = [] model_net.ipadapter.image_proj_model.to(torch.float32) # noqa for c in cc: @@ -1026,7 +968,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files): logger.info(f"\t{frame_idx}: {frame_path}") c = SparseCtrl.create_cond_mask(cn_ad_keyframe_idx, c, p.batch_size).cpu() - elif unit.accepts_multiple_inputs(): + elif unit.accepts_multiple_inputs: # ip-adapter should do prompt travel logger.info("IP-Adapter: control prompts will be traveled in the following way:") for frame_idx, frame_path in zip(unit.batch_keyframe_idx, unit.batch_image_files): @@ -1055,7 +997,7 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe c_full[cn_ad_keyframe_idx] = c c = c_full # handle batch condition and unconditional - if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs(): + if shared.opts.batch_cond_uncond and not unit.accepts_multiple_inputs: c = torch.cat([c, c], dim=0) return c @@ -1078,7 +1020,6 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe control_model_type.is_controlnet and model_net.control_model.global_average_pooling ) - control_mode = external_code.control_mode_from_value(unit.control_mode) forward_param = ControlParams( control_model=model_net, preprocessor=preprocessor_dict, @@ -1091,9 +1032,9 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe control_model_type=control_model_type, global_average_pooling=global_average_pooling, hr_hint_cond=hr_control, - hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix 
else HiResFixOption.BOTH, - soft_injection=control_mode != external_code.ControlMode.BALANCED, - cfg_injection=control_mode == external_code.ControlMode.CONTROL, + hr_option=unit.hr_option if high_res_fix else HiResFixOption.BOTH, + soft_injection=unit.control_mode != ControlMode.BALANCED, + cfg_injection=unit.control_mode == ControlMode.CONTROL, effective_region_mask=( get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :] if unit.effective_region_mask is not None @@ -1190,7 +1131,7 @@ def recolor_intensity_post_processing(x, i): is_low_vram = any(unit.low_vram for unit in self.enabled_units) - for i, param in enumerate(forward_params): + for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)): if param.control_model_type == ControlModelType.IPAdapter: if param.advanced_weighting is not None: logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}") @@ -1205,6 +1146,12 @@ def recolor_intensity_post_processing(x, i): weight = param.weight h, w, hr_y, hr_x = Script.get_target_dimensions(p) + if unit.pulid_mode == PuLIDMode.STYLE: + pulid_attn_setting = PULID_SETTING_STYLE + else: + assert unit.pulid_mode == PuLIDMode.FIDELITY + pulid_attn_setting = PULID_SETTING_FIDELITY + param.control_model.hook( model=unet, preprocessor_outputs=param.hint_cond, @@ -1215,6 +1162,7 @@ def recolor_intensity_post_processing(x, i): latent_width=w // 8, latent_height=h // 8, effective_region_mask=param.effective_region_mask, + pulid_attn_setting=pulid_attn_setting, ) if param.control_model_type == ControlModelType.Controlllite: param.control_model.hook( @@ -1352,7 +1300,7 @@ def batch_tab_process(self, p, batches, *args, **kwargs): unit.batch_images = iter([batch[unit_i] for batch in batches]) def batch_tab_process_each(self, p, *args, **kwargs): - for unit_i, unit in enumerate(self.enabled_units): + for unit in self.enabled_units: if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0: continue diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index 2bfc18e22..908f4f042 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -1,8 +1,8 @@ import json import gradio as gr import functools -from copy import copy -from typing import List, Optional, Union, Dict, Tuple, Literal +import itertools +from typing import List, Optional, Union, Dict, Tuple, Literal, Any from dataclasses import dataclass import numpy as np @@ -13,12 +13,18 @@ external_code, ) from annotator.util import HWC3 +from internal_controlnet.external_code import ControlNetUnit from scripts.logging import logger from scripts.controlnet_ui.openpose_editor import OpenposeEditor -from scripts.controlnet_ui.preset import ControlNetPresetUI from scripts.controlnet_ui.photopea import Photopea from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl -from scripts.enums import InputMode +from scripts.enums import ( + InputMode, + HiResFixOption, + PuLIDMode, + ControlMode, + ResizeMode, +) from modules import shared from modules.ui_components import FormRow, FormHTML, ToolButton @@ -121,72 +127,39 @@ def set_component(self, component: gr.components.Component): ) -class UiControlNetUnit(external_code.ControlNetUnit): - """The data class that stores all states of a ControlNetUnit.""" - - def __init__( - self, - input_mode: InputMode = InputMode.SIMPLE, - batch_images: Optional[Union[str, List[external_code.InputImage]]] = None, - output_dir: str = "", - 
loopback: bool = False, - merge_gallery_files: List[ - Dict[Union[Literal["name"], Literal["data"]], str] - ] = [], - use_preview_as_input: bool = False, - generated_image: Optional[np.ndarray] = None, - mask_image: Optional[np.ndarray] = None, - enabled: bool = True, - module: Optional[str] = None, - model: Optional[str] = None, - weight: float = 1.0, - image: Optional[Dict[str, np.ndarray]] = None, - *args, - **kwargs, - ): - if use_preview_as_input and generated_image is not None: - input_image = generated_image - module = "none" - else: - input_image = image - - # Prefer uploaded mask_image over hand-drawn mask. - if input_image is not None and mask_image is not None: - assert isinstance(input_image, dict) - input_image["mask"] = mask_image - - if merge_gallery_files and input_mode == InputMode.MERGE: - input_image = [ - {"image": read_image(file["name"])} for file in merge_gallery_files - ] - - super().__init__(enabled, module, model, weight, input_image, *args, **kwargs) - self.is_ui = True - self.input_mode = input_mode - self.batch_images = batch_images - self.output_dir = output_dir - self.loopback = loopback +def create_ui_unit( + input_mode: InputMode = InputMode.SIMPLE, + batch_images: Optional[Any] = None, + output_dir: str = "", + loopback: bool = False, + merge_gallery_files: List[Dict[Union[Literal["name"], Literal["data"]], str]] = [], + use_preview_as_input: bool = False, + generated_image: Optional[np.ndarray] = None, + *args, +) -> ControlNetUnit: + unit_dict = { + k: v + for k, v in zip( + vars(ControlNetUnit()).keys(), + itertools.chain( + [True, input_mode, batch_images, output_dir, loopback], args + ), + ) + } - def unfold_merged(self) -> List[external_code.ControlNetUnit]: - """Unfolds a merged unit to multiple units. Keeps the unit merged for - preprocessors that can accept multiple input images. - """ - if self.input_mode != InputMode.MERGE: - return [copy(self)] + if use_preview_as_input and generated_image is not None: + input_image = generated_image + unit_dict["module"] = "none" + else: + input_image = unit_dict["image"] - if self.accepts_multiple_inputs(): - self.input_mode = InputMode.SIMPLE - return [copy(self)] + if merge_gallery_files and input_mode == InputMode.MERGE: + input_image = [ + {"image": read_image(file["name"])} for file in merge_gallery_files + ] - assert isinstance(self.image, list) - result = [] - for image in self.image: - unit = copy(self) - unit.image = image["image"] - unit.input_mode = InputMode.SIMPLE - unit.weight = self.weight / len(self.image) - result.append(unit) - return result + unit_dict["image"] = input_image + return ControlNetUnit.from_dict(unit_dict) class ControlNetUiGroup(object): @@ -220,7 +193,6 @@ class ControlNetUiGroup(object): def __init__( self, is_img2img: bool, - default_unit: external_code.ControlNetUnit, photopea: Optional[Photopea], ): # Whether callbacks have been registered. @@ -229,13 +201,13 @@ def __init__( self.ui_initialized: bool = False self.is_img2img = is_img2img - self.default_unit = default_unit + self.default_unit = ControlNetUnit() self.photopea = photopea self.webcam_enabled = False self.webcam_mirrored = False # Note: All gradio elements declared in `render` will be defined as member variable. - # Update counter to trigger a force update of UiControlNetUnit. + # Update counter to trigger a force update of ControlNetUnit. # This is useful when a field with no event subscriber available changes. # e.g. gr.Gallery, gr.State, etc. 
self.update_unit_counter = None @@ -244,7 +216,7 @@ def __init__( self.generated_image_group = None self.generated_image = None self.mask_image_group = None - self.mask_image = None + self.effective_region_mask = None self.batch_tab = None self.batch_image_dir = None self.merge_tab = None @@ -282,7 +254,6 @@ def __init__( self.loopback = None self.use_preview_as_input = None self.openpose_editor = None - self.preset_panel = None self.upload_independent_img_in_img2img = None self.image_upload_panel = None self.save_detected_map = None @@ -293,10 +264,10 @@ def __init__( self.batch_image_dir_state = None self.output_dir_state = None self.advanced_weighting = gr.State(None) + self.pulid_mode = None # API-only fields self.ipadapter_input = gr.State(None) - self.effective_region_mask = gr.Image(value=None, visible=False) ControlNetUiGroup.all_ui_groups.append(self) @@ -329,11 +300,13 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: tool="sketch", elem_id=f"{elem_id_tabname}_{tabname}_input_image", elem_classes=["cnet-image"], - brush_color=shared.opts.img2img_inpaint_mask_brush_color - if hasattr( - shared.opts, "img2img_inpaint_mask_brush_color" - ) - else None, + brush_color=( + shared.opts.img2img_inpaint_mask_brush_color + if hasattr( + shared.opts, "img2img_inpaint_mask_brush_color" + ) + else None + ), ) self.image.preprocess = functools.partial( svg_preprocess, preprocess=self.image.preprocess @@ -369,11 +342,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: with gr.Group( visible=False, elem_classes=["cnet-mask-image-group"] ) as self.mask_image_group: - self.mask_image = gr.Image( + self.effective_region_mask = gr.Image( value=None, - label="Upload Mask", + label="Effective Region Mask", elem_id=f"{elem_id_tabname}_{tabname}_mask_image", - elem_classes=["cnet-mask-image"], + elem_classes=["cnet-effective-region-mask-image"], interactive=True, ) @@ -481,11 +454,10 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=not self.is_img2img, ) self.mask_upload = gr.Checkbox( - label="Mask Upload", + label="Effective Region Mask", value=False, elem_classes=["cnet-mask-upload"], elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox", - visible=not self.is_img2img, ) self.use_preview_as_input = gr.Checkbox( label="Preview as Input", @@ -515,7 +487,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]): - self.type_filter = (gr.Dropdown if shared.opts.data.get("controlnet_control_type_dropdown", False) else gr.Radio)( + self.type_filter = ( + gr.Dropdown + if shared.opts.data.get("controlnet_control_type_dropdown", False) + else gr.Radio + )( Preprocessor.get_all_preprocessor_tags(), label="Control Type", value="All", @@ -609,7 +585,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) self.control_mode = gr.Radio( - choices=[e.value for e in external_code.ControlMode], + choices=[e.value for e in ControlMode], value=self.default_unit.control_mode.value, label="Control Mode", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio", @@ -617,7 +593,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: ) self.resize_mode = gr.Radio( - choices=[e.value for e in external_code.ResizeMode], + choices=[e.value for e in ResizeMode], value=self.default_unit.resize_mode.value, label="Resize Mode", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio", @@ -626,7 +602,7 @@ def render(self, tabname: 
str, elem_id_tabname: str) -> None: ) self.hr_option = gr.Radio( - choices=[e.value for e in external_code.HiResFixOption], + choices=[e.value for e in HiResFixOption], value=self.default_unit.hr_option.value, label="Hires-Fix Option", elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio", @@ -634,9 +610,18 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: visible=False, ) + self.pulid_mode = gr.Radio( + choices=[e.value for e in PuLIDMode], + value=self.default_unit.pulid_mode.value, + label="PuLID Mode", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio", + elem_classes="controlnet_pulid_mode_radio", + visible=False, + ) + self.loopback = gr.Checkbox( label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", - value=self.default_unit.loopback, + value=False, elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox", elem_classes="controlnet_loopback_checkbox", visible=False, @@ -644,10 +629,6 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.advanced_weight_control.render() - self.preset_panel = ControlNetPresetUI( - id_prefix=f"{elem_id_tabname}_{tabname}_" - ) - self.batch_image_dir_state = gr.State("") self.output_dir_state = gr.State("") unit_args = ( @@ -661,7 +642,6 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.merge_gallery, self.use_preview_as_input, self.generated_image, - self.mask_image, # End of Non-persistent fields. self.enabled, self.module, @@ -681,34 +661,17 @@ def render(self, tabname: str, elem_id_tabname: str) -> None: self.hr_option, self.save_detected_map, self.advanced_weighting, + self.effective_region_mask, + self.pulid_mode, ) - unit = gr.State(self.default_unit) - for comp in unit_args + (self.update_unit_counter,): - event_subscribers = [] - if hasattr(comp, "edit"): - event_subscribers.append(comp.edit) - elif hasattr(comp, "click"): - event_subscribers.append(comp.click) - elif isinstance(comp, gr.Slider) and hasattr(comp, "release"): - event_subscribers.append(comp.release) - elif hasattr(comp, "change"): - event_subscribers.append(comp.change) - - if hasattr(comp, "clear"): - event_subscribers.append(comp.clear) - - for event_subscriber in event_subscribers: - event_subscriber( - fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit - ) - + unit = gr.State(ControlNetUnit()) ( ControlNetUiGroup.a1111_context.img2img_submit_button if self.is_img2img else ControlNetUiGroup.a1111_context.txt2img_submit_button ).click( - fn=UiControlNetUnit, + fn=create_ui_unit, inputs=list(unit_args), outputs=unit, queue=False, @@ -793,10 +756,12 @@ def refresh_all_models(model: str): def register_build_sliders(self): def build_sliders(module: str, pp: bool): preprocessor = Preprocessor.get_preprocessor(module) - slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy() + slider_resolution_kwargs = ( + preprocessor.slider_resolution.gradio_update_kwargs.copy() + ) if pp: - slider_resolution_kwargs['visible'] = False + slider_resolution_kwargs["visible"] = False grs = [ gr.update(**slider_resolution_kwargs), @@ -842,9 +807,7 @@ def filter_selected(k: str): gr.Dropdown.update( value=default_option, choices=filtered_preprocessor_list ), - gr.Dropdown.update( - value=default_model, choices=filtered_model_list - ), + gr.Dropdown.update(value=default_model, choices=filtered_model_list), ] self.type_filter.change( @@ -883,7 +846,9 @@ def sd_version_changed(type_filter: str, 
current_model: str): ) def register_run_annotator(self): - def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str): + def run_annotator( + image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm, model: str + ): if image is None: return ( gr.update(value=None, visible=True), @@ -946,16 +911,16 @@ def is_openpose(module: str): and shared.opts.data.get("controlnet_clip_detector_on_cpu", False) ), json_pose_callback=( - json_acceptor.accept - if is_openpose(module) - else None + json_acceptor.accept if is_openpose(module) else None ), model=model, ) return ( # Update to `generated_image` - gr.update(value=result.display_image, visible=True, interactive=False), + gr.update( + value=result.display_images[0], visible=True, interactive=False + ), # preprocessor_preview gr.update(value=True), # openpose editor @@ -970,12 +935,16 @@ def is_openpose(module: str): self.processor_res, self.threshold_a, self.threshold_b, - ControlNetUiGroup.a1111_context.img2img_w_slider - if self.is_img2img - else ControlNetUiGroup.a1111_context.txt2img_w_slider, - ControlNetUiGroup.a1111_context.img2img_h_slider - if self.is_img2img - else ControlNetUiGroup.a1111_context.txt2img_h_slider, + ( + ControlNetUiGroup.a1111_context.img2img_w_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_w_slider + ), + ( + ControlNetUiGroup.a1111_context.img2img_h_slider + if self.is_img2img + else ControlNetUiGroup.a1111_context.txt2img_h_slider + ), self.pixel_perfect, self.resize_mode, self.model, @@ -1122,22 +1091,17 @@ def register_shift_upload_mask(self): else (gr.update(visible=True), gr.update()) ), inputs=[self.mask_upload], - outputs=[self.mask_image_group, self.mask_image], + outputs=[self.mask_image_group, self.effective_region_mask], show_progress=False, ) - if self.upload_independent_img_in_img2img is not None: - self.upload_independent_img_in_img2img.change( - fn=lambda checked: ( - # Uncheck `upload_mask` when not using independent input. 
- gr.update(visible=False, value=False) - if not checked - else gr.update(visible=True) - ), - inputs=[self.upload_independent_img_in_img2img], - outputs=[self.mask_upload], - show_progress=False, - ) + def register_shift_pulid_mode(self): + self.model.change( + fn=lambda model: gr.update(visible="pulid" in model.lower()), + inputs=[self.model], + outputs=[self.pulid_mode], + show_progress=False, + ) def register_sync_batch_dir(self): def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir): @@ -1241,6 +1205,7 @@ def register_core_callbacks(self): self.register_build_sliders() self.register_shift_preview() self.register_shift_upload_mask() + self.register_shift_pulid_mode() self.register_create_canvas() self.register_clear_preview() self.register_multi_images_upload() @@ -1250,14 +1215,6 @@ def register_core_callbacks(self): self.model, ) assert self.type_filter is not None - self.preset_panel.register_callbacks( - self, - self.type_filter, - *[ - getattr(self, key) - for key in vars(external_code.ControlNetUnit()).keys() - ], - ) self.advanced_weight_control.register_callbacks( self.weight, self.advanced_weighting, diff --git a/scripts/controlnet_ui/preset.py b/scripts/controlnet_ui/preset.py deleted file mode 100644 index 3010d2617..000000000 --- a/scripts/controlnet_ui/preset.py +++ /dev/null @@ -1,305 +0,0 @@ -import os -import gradio as gr - -from typing import Dict, List - -from modules import scripts -from modules.ui_components import ToolButton -from scripts.infotext import parse_unit, serialize_unit -from scripts.logging import logger -from scripts import external_code -from scripts.supported_preprocessor import Preprocessor - -save_symbol = "\U0001f4be" # 💾 -delete_symbol = "\U0001f5d1\ufe0f" # 🗑️ -refresh_symbol = "\U0001f504" # 🔄 -reset_symbol = "\U000021A9" # ↩ - -NEW_PRESET = "New Preset" - - -def load_presets(preset_dir: str) -> Dict[str, str]: - if not os.path.exists(preset_dir): - os.makedirs(preset_dir) - return {} - - presets = {} - for filename in os.listdir(preset_dir): - if filename.endswith(".txt"): - with open(os.path.join(preset_dir, filename), "r") as f: - name = filename.replace(".txt", "") - if name == NEW_PRESET: - continue - presets[name] = f.read() - return presets - - -def infer_control_type(module: str, model: str) -> str: - p = Preprocessor.get_preprocessor(module) - assert p is not None - matched_tags = [ - tag - for tag in p.tags - if any(f in model.lower() for f in Preprocessor.tag_to_filters(tag)) - ] - if len(matched_tags) != 1: - raise ValueError( - f"Unable to infer control type from module {module} and model {model}" - ) - return matched_tags[0] - - -class ControlNetPresetUI(object): - preset_directory = os.path.join(scripts.basedir(), "presets") - presets = load_presets(preset_directory) - - def __init__(self, id_prefix: str): - with gr.Row(): - self.dropdown = gr.Dropdown( - label="Presets", - show_label=True, - elem_classes=["cnet-preset-dropdown"], - choices=ControlNetPresetUI.dropdown_choices(), - value=NEW_PRESET, - ) - self.reset_button = ToolButton( - value=reset_symbol, - elem_classes=["cnet-preset-reset"], - tooltip="Reset preset", - visible=False, - ) - self.save_button = ToolButton( - value=save_symbol, - elem_classes=["cnet-preset-save"], - tooltip="Save preset", - ) - self.delete_button = ToolButton( - value=delete_symbol, - elem_classes=["cnet-preset-delete"], - tooltip="Delete preset", - ) - self.refresh_button = ToolButton( - value=refresh_symbol, - elem_classes=["cnet-preset-refresh"], - tooltip="Refresh preset", - 
) - - with gr.Box( - elem_classes=["popup-dialog", "cnet-preset-enter-name"], - elem_id=f"{id_prefix}_cnet_preset_enter_name", - ) as self.name_dialog: - with gr.Row(): - self.preset_name = gr.Textbox( - label="Preset name", - show_label=True, - lines=1, - elem_classes=["cnet-preset-name"], - ) - self.confirm_preset_name = ToolButton( - value=save_symbol, - elem_classes=["cnet-preset-confirm-name"], - tooltip="Save preset", - ) - - def register_callbacks( - self, - uigroup, - control_type: gr.Radio, - *ui_states, - ): - def apply_preset(name: str, control_type: str, *ui_states): - if name == NEW_PRESET: - return ( - gr.update(visible=False), - *( - (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) - ), - ) - - assert name in ControlNetPresetUI.presets - - infotext = ControlNetPresetUI.presets[name] - preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) - preset_unit.image = None - current_unit.image = None - - # Do not compare module param that are not used in preset. - for module_param in ("processor_res", "threshold_a", "threshold_b"): - if getattr(preset_unit, module_param) == -1: - setattr(current_unit, module_param, -1) - - # No update necessary. - if vars(current_unit) == vars(preset_unit): - return ( - gr.update(visible=False), - *( - (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) - ), - ) - - unit = preset_unit - - try: - new_control_type = infer_control_type(unit.module, unit.model) - except ValueError as e: - logger.error(e) - new_control_type = control_type - - return ( - gr.update(visible=True), - gr.update(value=new_control_type), - *[ - gr.update(value=value) if value is not None else gr.update() - for value in vars(unit).values() - ], - ) - - for element, action in ( - (self.dropdown, "change"), - (self.reset_button, "click"), - ): - getattr(element, action)( - fn=apply_preset, - inputs=[self.dropdown, control_type, *ui_states], - outputs=[self.delete_button, control_type, *ui_states], - show_progress="hidden", - ).then( - fn=lambda: gr.update(visible=False), - inputs=None, - outputs=[self.reset_button], - ) - - def save_preset(name: str, *ui_states): - if name == NEW_PRESET: - return gr.update(visible=True), gr.update(), gr.update() - - ControlNetPresetUI.save_preset( - name, external_code.ControlNetUnit(*ui_states) - ) - return ( - gr.update(), # name dialog - gr.update(choices=ControlNetPresetUI.dropdown_choices(), value=name), - gr.update(visible=False), # Reset button - ) - - self.save_button.click( - fn=save_preset, - inputs=[self.dropdown, *ui_states], - outputs=[self.name_dialog, self.dropdown, self.reset_button], - show_progress="hidden", - ).then( - fn=None, - _js=f""" - (name) => {{ - if (name === "{NEW_PRESET}") - popup(gradioApp().getElementById('{self.name_dialog.elem_id}')); - }}""", - inputs=[self.dropdown], - ) - - def delete_preset(name: str): - ControlNetPresetUI.delete_preset(name) - return gr.Dropdown.update( - choices=ControlNetPresetUI.dropdown_choices(), - value=NEW_PRESET, - ), gr.update(visible=False) - - self.delete_button.click( - fn=delete_preset, - inputs=[self.dropdown], - outputs=[self.dropdown, self.reset_button], - show_progress="hidden", - ) - - self.name_dialog.visible = False - - def save_new_preset(new_name: str, *ui_states): - if new_name == NEW_PRESET: - logger.warn(f"Cannot save preset with reserved name '{NEW_PRESET}'") - return gr.update(visible=False), gr.update() - - ControlNetPresetUI.save_preset( - new_name, 
external_code.ControlNetUnit(*ui_states) - ) - return gr.update(visible=False), gr.update( - choices=ControlNetPresetUI.dropdown_choices(), value=new_name - ) - - self.confirm_preset_name.click( - fn=save_new_preset, - inputs=[self.preset_name, *ui_states], - outputs=[self.name_dialog, self.dropdown], - show_progress="hidden", - ).then(fn=None, _js="closePopup") - - self.refresh_button.click( - fn=ControlNetPresetUI.refresh_preset, - inputs=None, - outputs=[self.dropdown], - show_progress="hidden", - ) - - def update_reset_button(preset_name: str, *ui_states): - if preset_name == NEW_PRESET: - return gr.update(visible=False) - - infotext = ControlNetPresetUI.presets[preset_name] - preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) - preset_unit.image = None - current_unit.image = None - - # Do not compare module param that are not used in preset. - for module_param in ("processor_res", "threshold_a", "threshold_b"): - if getattr(preset_unit, module_param) == -1: - setattr(current_unit, module_param, -1) - - return gr.update(visible=vars(current_unit) != vars(preset_unit)) - - for ui_state in ui_states: - if isinstance(ui_state, gr.Image): - continue - - for action in ("edit", "click", "change", "clear", "release"): - if action == "release" and not isinstance(ui_state, gr.Slider): - continue - - if hasattr(ui_state, action): - getattr(ui_state, action)( - fn=update_reset_button, - inputs=[self.dropdown, *ui_states], - outputs=[self.reset_button], - ) - - @staticmethod - def dropdown_choices() -> List[str]: - return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] - - @staticmethod - def save_preset(name: str, unit: external_code.ControlNetUnit): - infotext = serialize_unit(unit) - with open( - os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w" - ) as f: - f.write(infotext) - - ControlNetPresetUI.presets[name] = infotext - - @staticmethod - def delete_preset(name: str): - if name not in ControlNetPresetUI.presets: - return - - del ControlNetPresetUI.presets[name] - - file = os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt") - if os.path.exists(file): - os.unlink(file) - - @staticmethod - def refresh_preset(): - ControlNetPresetUI.presets = load_presets(ControlNetPresetUI.preset_directory) - return gr.update(choices=ControlNetPresetUI.dropdown_choices()) diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py index 34173e04d..5b8222290 100644 --- a/scripts/controlnet_version.py +++ b/scripts/controlnet_version.py @@ -1,8 +1,4 @@ -from scripts.logging import logger - -version_flag = 'v1.1.445' - -logger.info(f"ControlNet {version_flag}") +version_flag = 'v1.1.448' # A smart trick to know if user has updated as well as if user has restarted terminal. # Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py" # This means if user did not completely restart terminal, the "version_flag" will be the previous version. 
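The comment kept at the bottom of `controlnet_version.py` leans on Python's module cache: because `controlnet.py` deliberately does not call `importlib.reload` on this module, a user who pulls an update without restarting the terminal still holds the old `version_flag` in memory, which is exactly what reveals the stale process. A self-contained sketch of that caching behavior (the module name here is made up for the demo):

```python
import importlib
import pathlib
import sys
import tempfile

# Write a throwaway stand-in for controlnet_version.py.
tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "cn_version_demo.py").write_text("version_flag = 'v1.1.447'\n")
sys.path.insert(0, str(tmp))

import cn_version_demo  # first import executes the file and caches the module

assert cn_version_demo.version_flag == "v1.1.447"

# Simulate pulling an update without restarting the interpreter.
(tmp / "cn_version_demo.py").write_text("version_flag = 'v1.1.448'\n")

# A repeated import hits sys.modules and never re-reads the file ...
assert importlib.import_module("cn_version_demo").version_flag == "v1.1.447"

# ... only an explicit reload (or a process restart) picks up the new value.
importlib.reload(cn_version_demo)
assert cn_version_demo.version_flag == "v1.1.448"
```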
diff --git a/scripts/enums.py b/scripts/enums.py
index 327f36431..477f38ff4 100644
--- a/scripts/enums.py
+++ b/scripts/enums.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, NamedTuple
+from typing import List, NamedTuple
 from functools import lru_cache
@@ -224,19 +224,6 @@ class HiResFixOption(Enum):
     LOW_RES_ONLY = "Low res only"
     HIGH_RES_ONLY = "High res only"

-    @staticmethod
-    def from_value(value: Any) -> "HiResFixOption":
-        if isinstance(value, str) and value.startswith("HiResFixOption."):
-            _, field = value.split(".")
-            return getattr(HiResFixOption, field)
-        if isinstance(value, str):
-            return HiResFixOption(value)
-        elif isinstance(value, int):
-            return [x for x in HiResFixOption][value]
-        else:
-            assert isinstance(value, HiResFixOption)
-            return value
-
 class InputMode(Enum):
     # Single image to a single ControlNet unit.
@@ -247,3 +234,42 @@ class InputMode(Enum):
     # Input is a directory. 1 generation. Each generation takes N input images
     # from the directory.
     MERGE = "merge"
+
+
+class PuLIDMode(Enum):
+    FIDELITY = "Fidelity"
+    STYLE = "Extremely style"
+
+
+class ControlMode(Enum):
+    """
+    The improved guess mode.
+    """
+
+    BALANCED = "Balanced"
+    PROMPT = "My prompt is more important"
+    CONTROL = "ControlNet is more important"
+
+
+class BatchOption(Enum):
+    DEFAULT = "All ControlNet units for all images in a batch"
+    SEPARATE = "Each ControlNet unit for each image in a batch"
+
+
+class ResizeMode(Enum):
+    """
+    Resize modes for ControlNet input images.
+    """
+
+    RESIZE = "Just Resize"
+    INNER_FIT = "Crop and Resize"
+    OUTER_FIT = "Resize and Fill"
+
+    def int_value(self):
+        if self == ResizeMode.RESIZE:
+            return 0
+        elif self == ResizeMode.INNER_FIT:
+            return 1
+        elif self == ResizeMode.OUTER_FIT:
+            return 2
+        assert False, "NOTREACHED"
diff --git a/scripts/infotext.py b/scripts/infotext.py
index 9dbdad1d8..68202a890 100644
--- a/scripts/infotext.py
+++ b/scripts/infotext.py
@@ -1,60 +1,13 @@
-from typing import List, Tuple, Union
-
+from typing import List, Tuple
+from enum import Enum
 import gradio as gr

 from modules.processing import StableDiffusionProcessing

-from scripts import external_code
+from internal_controlnet.external_code import ControlNetUnit
 from scripts.logging import logger

-def field_to_displaytext(fieldname: str) -> str:
-    return " ".join([word.capitalize() for word in fieldname.split("_")])
-
-
-def displaytext_to_field(text: str) -> str:
-    return "_".join([word.lower() for word in text.split(" ")])
-
-
-def parse_value(value: str) -> Union[str, float, int, bool]:
-    if value in ("True", "False"):
-        return value == "True"
-    try:
-        return int(value)
-    except ValueError:
-        try:
-            return float(value)
-        except ValueError:
-            return value  # Plain string.
-
-
-def serialize_unit(unit: external_code.ControlNetUnit) -> str:
-    excluded_fields = external_code.ControlNetUnit.infotext_excluded_fields()
-
-    log_value = {
-        field_to_displaytext(field): getattr(unit, field)
-        for field in vars(external_code.ControlNetUnit()).keys()
-        if field not in excluded_fields and getattr(unit, field) != -1
-        # Note: exclude hidden slider values.
-    }
-    if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()):
-        logger.error(f"Unexpected tokens encountered:\n{log_value}")
-        return ""
-
-    return ", ".join(f"{field}: {value}" for field, value in log_value.items())
-
-
-def parse_unit(text: str) -> external_code.ControlNetUnit:
-    return external_code.ControlNetUnit(
-        enabled=True,
-        **{
-            displaytext_to_field(key): parse_value(value)
-            for item in text.split(",")
-            for (key, value) in (item.strip().split(": "),)
-        },
-    )
-
-
 class Infotext(object):
     def __init__(self) -> None:
         self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
@@ -74,11 +27,7 @@ def register_unit(self, unit_index: int, uigroup) -> None:
         iocomponents.
         """
         unit_prefix = Infotext.unit_prefix(unit_index)
-        for field in vars(external_code.ControlNetUnit()).keys():
-            # Exclude image for infotext.
-            if field == "image":
-                continue
-
+        for field in ControlNetUnit.infotext_fields():
             # Every field in ControlNetUnit should have a corresponding
             # IOComponent in ControlNetUiGroup.
             io_component = getattr(uigroup, field)
@@ -87,13 +36,11 @@ def register_unit(self, unit_index: int, uigroup) -> None:
             self.paste_field_names.append(component_locator)

     @staticmethod
-    def write_infotext(
-        units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
-    ):
+    def write_infotext(units: List[ControlNetUnit], p: StableDiffusionProcessing):
         """Write infotext to `p`."""
         p.extra_generation_params.update(
             {
-                Infotext.unit_prefix(i): serialize_unit(unit)
+                Infotext.unit_prefix(i): unit.serialize()
                 for i, unit in enumerate(units)
                 if unit.enabled
             }
@@ -109,14 +56,19 @@ def on_infotext_pasted(infotext: str, results: dict) -> None:
             assert isinstance(v, str), f"Expect string but got {v}."

             try:
-                for field, value in vars(parse_unit(v)).items():
-                    if field == "image":
+                for field, value in vars(ControlNetUnit.parse(v)).items():
+                    if field not in ControlNetUnit.infotext_fields():
                         continue

                     if value is None:
-                        logger.debug(f"InfoText: Skipping {field} because value is None.")
+                        logger.debug(
+                            f"InfoText: Skipping {field} because value is None."
+ ) continue component_locator = f"{k} {field}" + if isinstance(value, Enum): + value = value.value + updates[component_locator] = value logger.debug(f"InfoText: Setting {component_locator} = {value}") except Exception as e: diff --git a/scripts/ipadapter/image_proj_models.py b/scripts/ipadapter/image_proj_models.py index 8594ac99b..d8dd12157 100644 --- a/scripts/ipadapter/image_proj_models.py +++ b/scripts/ipadapter/image_proj_models.py @@ -269,3 +269,65 @@ def forward(self, x): latents = self.proj_out(latents) return self.norm_out(latents) + + +class PuLIDEncoder(nn.Module): + def __init__(self, width=1280, context_dim=2048, num_token=5): + super().__init__() + self.num_token = num_token + self.context_dim = context_dim + h1 = min((context_dim * num_token) // 4, 1024) + h2 = min((context_dim * num_token) // 2, 1024) + self.body = nn.Sequential( + nn.Linear(width, h1), + nn.LayerNorm(h1), + nn.LeakyReLU(), + nn.Linear(h1, h2), + nn.LayerNorm(h2), + nn.LeakyReLU(), + nn.Linear(h2, context_dim * num_token), + ) + + for i in range(5): + setattr( + self, + f"mapping_{i}", + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + setattr( + self, + f"mapping_patch_{i}", + nn.Sequential( + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, 1024), + nn.LayerNorm(1024), + nn.LeakyReLU(), + nn.Linear(1024, context_dim), + ), + ) + + def forward(self, x, y): + # x shape [N, C] + x = self.body(x) + x = x.reshape(-1, self.num_token, self.context_dim) + + hidden_states = () + for i, emb in enumerate(y): + hidden_state = getattr(self, f"mapping_{i}")(emb[:, :1]) + getattr( + self, f"mapping_patch_{i}" + )(emb[:, 1:]).mean(dim=1, keepdim=True) + hidden_states += (hidden_state,) + hidden_states = torch.cat(hidden_states, dim=1) + + return torch.cat([x, hidden_states], dim=1) diff --git a/scripts/ipadapter/ipadapter_model.py b/scripts/ipadapter/ipadapter_model.py index 7314c9b2d..16d9ac4c5 100644 --- a/scripts/ipadapter/ipadapter_model.py +++ b/scripts/ipadapter/ipadapter_model.py @@ -12,6 +12,7 @@ MLPProjModel, MLPProjModelFaceId, ProjModelFaceIdPlus, + PuLIDEncoder, ) @@ -71,6 +72,7 @@ def __init__( is_faceid: bool, is_portrait: bool, is_instantid: bool, + is_pulid: bool, is_v2: bool, ): super().__init__() @@ -85,9 +87,12 @@ def __init__( self.is_v2 = is_v2 self.is_faceid = is_faceid self.is_instantid = is_instantid + self.is_pulid = is_pulid self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4 - if is_instantid: + if self.is_pulid: + self.image_proj_model = PuLIDEncoder() + elif self.is_instantid: self.image_proj_model = self.init_proj_instantid() elif is_faceid: self.image_proj_model = self.init_proj_faceid() @@ -235,6 +240,34 @@ def _get_image_embeds_instantid( self.image_proj_model(torch.zeros_like(prompt_image_emb)), ) + def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed: + """Get image embeds for pulid.""" + id_cond = torch.cat( + [ + pulid_proj_input.id_ante_embedding.to( + device=self.device, dtype=torch.float32 + ), + pulid_proj_input.id_cond_vit.to( + device=self.device, dtype=torch.float32 + ), + ], + dim=-1, + ) + id_vit_hidden = [ + t.to(device=self.device, dtype=torch.float32) + for t in pulid_proj_input.id_vit_hidden + ] + return ImageEmbed( + self.image_proj_model( + id_cond, + id_vit_hidden, + ), + self.image_proj_model( + torch.zeros_like(id_cond), + [torch.zeros_like(t) for t in 
id_vit_hidden], + ), + ) + @staticmethod def load(state_dict: dict, model_name: str) -> IPAdapterModel: """ @@ -245,6 +278,7 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_v2 = "v2" in model_name is_faceid = "faceid" in model_name is_instantid = "instant_id" in model_name + is_pulid = "pulid" in model_name.lower() is_portrait = "portrait" in model_name is_full = "proj.3.weight" in state_dict["image_proj"] is_plus = ( @@ -256,8 +290,8 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: sdxl = cross_attention_dim == 2048 sdxl_plus = sdxl and is_plus - if is_instantid: - # InstantID does not use clip embedding. + if is_instantid or is_pulid: + # InstantID/PuLID does not use clip embedding. clip_embeddings_dim = None elif is_faceid: if is_plus: @@ -291,10 +325,13 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel: is_portrait=is_portrait, is_instantid=is_instantid, is_v2=is_v2, + is_pulid=is_pulid, ) def get_image_emb(self, preprocessor_output) -> ImageEmbed: - if self.is_instantid: + if self.is_pulid: + return self._get_image_embeds_pulid(preprocessor_output) + elif self.is_instantid: return self._get_image_embeds_instantid(preprocessor_output) elif self.is_faceid and self.is_plus: # Note: FaceID plus uses both face_embed and clip_embed. diff --git a/scripts/ipadapter/plugable_ipadapter.py b/scripts/ipadapter/plugable_ipadapter.py index b56522489..72c0e6652 100644 --- a/scripts/ipadapter/plugable_ipadapter.py +++ b/scripts/ipadapter/plugable_ipadapter.py @@ -1,8 +1,9 @@ import itertools import torch import math -from typing import Union, Dict, Optional +from typing import Union, Dict, Optional, Callable +from .pulid_attn import PuLIDAttnSetting from .ipadapter_model import ImageEmbed, IPAdapterModel from ..enums import StableDiffusionVersion, TransformerID @@ -93,7 +94,7 @@ def clear_all_ip_adapter(): class PlugableIPAdapter(torch.nn.Module): def __init__(self, ipadapter: IPAdapterModel): super().__init__() - self.ipadapter = ipadapter + self.ipadapter: IPAdapterModel = ipadapter self.disable_memory_management = True self.dtype = None self.weight: Union[float, Dict[int, float]] = 1.0 @@ -103,6 +104,7 @@ def __init__(self, ipadapter: IPAdapterModel): self.latent_width: int = 0 self.latent_height: int = 0 self.effective_region_mask = None + self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None def reset(self): self.cache = {} @@ -118,6 +120,7 @@ def hook( latent_width: int, latent_height: int, effective_region_mask: Optional[torch.Tensor], + pulid_attn_setting: Optional[PuLIDAttnSetting] = None, dtype=torch.float32, ): global current_model @@ -128,6 +131,7 @@ def hook( self.latent_width = latent_width self.latent_height = latent_height self.effective_region_mask = effective_region_mask + self.pulid_attn_setting = pulid_attn_setting self.cache = {} @@ -186,7 +190,9 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: # sequence_length = (latent_height * factor) * (latent_height * factor) # sequence_length = (latent_height * latent_height) * factor ^ 2 factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height)) - assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" + assert ( + factor > 0 + ), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" mask_h = int(self.latent_height * factor) mask_w = int(self.latent_width * factor) @@ -199,6 +205,71 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: mask = 
mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2]) return out * mask + def attn_eval( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + cond_uncond_image_emb: torch.Tensor, + attn_heads: int, + head_dim: int, + emb_to_k: Callable[[torch.Tensor], torch.Tensor], + emb_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + if self.ipadapter.is_pulid: + assert self.pulid_attn_setting is not None + return self.pulid_attn_setting.eval( + hidden_states, + query, + cond_uncond_image_emb, + attn_heads, + head_dim, + emb_to_k, + emb_to_v, + ) + else: + return self._attn_eval_ipadapter( + hidden_states, + query, + cond_uncond_image_emb, + attn_heads, + head_dim, + emb_to_k, + emb_to_v, + ) + + def _attn_eval_ipadapter( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + cond_uncond_image_emb: torch.Tensor, + attn_heads: int, + head_dim: int, + emb_to_k: Callable[[torch.Tensor], torch.Tensor], + emb_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + assert hidden_states.ndim == 3 + batch_size, sequence_length, inner_dim = hidden_states.shape + ip_k = emb_to_k(cond_uncond_image_emb) + ip_v = emb_to_v(cond_uncond_image_emb) + + ip_k, ip_v = map( + lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2), + (ip_k, ip_v), + ) + assert ip_k.dtype == ip_v.dtype + + # On MacOS, q can be float16 instead of float32. + # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 + if query.dtype != ip_k.dtype: + ip_k = ip_k.to(dtype=query.dtype) + ip_v = ip_v.to(dtype=query.dtype) + + ip_out = torch.nn.functional.scaled_dot_product_attention( + query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False + ) + ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim) + return ip_out + @torch.no_grad() def patch_forward(self, number: int, transformer_index: int): @torch.no_grad() @@ -220,27 +291,15 @@ def forward(attn_blk, x, q): k_key = f"{number * 2 + 1}_to_k_ip" v_key = f"{number * 2 + 1}_to_v_ip" - cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark) - ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device) - ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device) - - ip_k, ip_v = map( - lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2), - (ip_k, ip_v), + ip_out = self.attn_eval( + hidden_states=x, + query=q, + cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark), + attn_heads=h, + head_dim=head_dim, + emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device), + emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device), ) - assert ip_k.dtype == ip_v.dtype - - # On MacOS, q can be float16 instead of float32. 
- # https://github.com/Mikubill/sd-webui-controlnet/issues/2208 - if q.dtype != ip_k.dtype: - ip_k = ip_k.to(dtype=q.dtype) - ip_v = ip_v.to(dtype=q.dtype) - - ip_out = torch.nn.functional.scaled_dot_product_attention( - q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False - ) - ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim) - return self.apply_effective_region_mask(ip_out * weight) return forward diff --git a/scripts/ipadapter/presets.py b/scripts/ipadapter/presets.py index 275f70be5..764c83c98 100644 --- a/scripts/ipadapter/presets.py +++ b/scripts/ipadapter/presets.py @@ -166,6 +166,12 @@ def match_model(model_name: str) -> IPAdapterPreset: model="ip-adapter-faceid-portrait_sdxl", sd_version=StableDiffusionVersion.SDXL, ), + IPAdapterPreset( + name="pulid", + module="ip-adapter_pulid", + model="ip-adapter_pulid_sdxl_fp16", + sd_version=StableDiffusionVersion.SDXL, + ), ] _preset_by_model = {p.model: p for p in ipadapter_presets} diff --git a/scripts/ipadapter/pulid_attn.py b/scripts/ipadapter/pulid_attn.py new file mode 100644 index 000000000..e2823470c --- /dev/null +++ b/scripts/ipadapter/pulid_attn.py @@ -0,0 +1,94 @@ +import torch +import torch.nn.functional as F +from dataclasses import dataclass +from typing import Callable + + +@dataclass +class PuLIDAttnSetting: + num_zero: int = 0 + ortho: bool = False + ortho_v2: bool = False + + def eval( + self, + hidden_states: torch.Tensor, + query: torch.Tensor, + id_embedding: torch.Tensor, + attn_heads: int, + head_dim: int, + id_to_k: Callable[[torch.Tensor], torch.Tensor], + id_to_v: Callable[[torch.Tensor], torch.Tensor], + ): + assert hidden_states.ndim == 3 + batch_size, sequence_length, inner_dim = hidden_states.shape + + if self.num_zero == 0: + id_key = id_to_k(id_embedding).to(query.dtype) + id_value = id_to_v(id_embedding).to(query.dtype) + else: + zero_tensor = torch.zeros( + (id_embedding.size(0), self.num_zero, id_embedding.size(-1)), + dtype=id_embedding.dtype, + device=id_embedding.device, + ) + id_key = id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to( + query.dtype + ) + id_value = id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to( + query.dtype + ) + + id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + id_hidden_states = F.scaled_dot_product_attention( + query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + id_hidden_states = id_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn_heads * head_dim + ) + id_hidden_states = id_hidden_states.to(query.dtype) + + if not self.ortho and not self.ortho_v2: + return id_hidden_states + elif self.ortho_v2: + orig_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + id_hidden_states = id_hidden_states.to(torch.float32) + attn_map = query @ id_key.transpose(-2, -1) + attn_mean = attn_map.softmax(dim=-1).mean(dim=1) + attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True) + projection = ( + torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) + / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) + * hidden_states + ) + orthogonal = id_hidden_states + (attn_mean - 1) * projection + return orthogonal.to(orig_dtype) + else: + orig_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + id_hidden_states = id_hidden_states.to(torch.float32) + projection = 
( + torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True) + / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True) + * hidden_states + ) + orthogonal = id_hidden_states - projection + return orthogonal.to(orig_dtype) + + +PULID_SETTING_FIDELITY = PuLIDAttnSetting( + num_zero=8, + ortho=False, + ortho_v2=True, +) + +PULID_SETTING_STYLE = PuLIDAttnSetting( + num_zero=16, + ortho=True, + ortho_v2=False, +) diff --git a/scripts/preprocessor/__init__.py b/scripts/preprocessor/__init__.py index 6bbcb762f..b330e73ce 100644 --- a/scripts/preprocessor/__init__.py +++ b/scripts/preprocessor/__init__.py @@ -1,3 +1,4 @@ +from .pulid import * from .inpaint import * from .lama_inpaint import * from .ip_adapter_auto import * diff --git a/scripts/preprocessor/inpaint.py b/scripts/preprocessor/inpaint.py index 4605dc252..25874196a 100644 --- a/scripts/preprocessor/inpaint.py +++ b/scripts/preprocessor/inpaint.py @@ -1,18 +1,7 @@ -import numpy as np - +from scripts.utils import visualize_inpaint_mask from ..supported_preprocessor import Preprocessor, PreprocessorParameter -def visualize_inpaint_mask(img): - if img.ndim == 3 and img.shape[2] == 4: - result = img.copy() - mask = result[:, :, 3] - mask = 255 - mask // 2 - result[:, :, 3] = mask - return np.ascontiguousarray(result.copy()) - return img - - class PreprocessorInpaint(Preprocessor): def __init__(self): super().__init__(name="inpaint") @@ -23,9 +12,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result): - return visualize_inpaint_mask(result) - def __call__( self, input_image, @@ -35,7 +21,10 @@ def __call__( slider_3=None, **kwargs ): - return input_image + return Preprocessor.Result( + value=input_image, + display_images=visualize_inpaint_mask(input_image)[None, :, :, :], + ) class PreprocessorInpaintOnly(Preprocessor): @@ -47,9 +36,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result): - return visualize_inpaint_mask(result) - def __call__( self, input_image, @@ -59,7 +45,10 @@ def __call__( slider_3=None, **kwargs ): - return input_image + return Preprocessor.Result( + value=input_image, + display_images=visualize_inpaint_mask(input_image)[None, :, :, :], + ) Preprocessor.add_supported_preprocessor(PreprocessorInpaint()) diff --git a/scripts/preprocessor/lama_inpaint.py b/scripts/preprocessor/lama_inpaint.py index 33aff60bf..1cd1c521c 100644 --- a/scripts/preprocessor/lama_inpaint.py +++ b/scripts/preprocessor/lama_inpaint.py @@ -2,7 +2,7 @@ import numpy as np from ..supported_preprocessor import Preprocessor, PreprocessorParameter -from ..utils import resize_image_with_pad +from ..utils import resize_image_with_pad, visualize_inpaint_mask class PreprocessorLamaInpaint(Preprocessor): @@ -15,12 +15,6 @@ def __init__(self): self.accepts_mask = True self.requires_mask = True - def get_display_image(self, input_image: np.ndarray, result: np.ndarray): - """For lama inpaint, display image should not contain mask.""" - assert result.ndim == 3 - assert result.shape[2] == 4 - return result[:, :, :3] - def __call__( self, input_image, @@ -56,7 +50,13 @@ def __call__( fin_color = fin_color.clip(0, 255).astype(np.uint8) result = np.concatenate([fin_color, raw_mask], axis=2) - return result + return Preprocessor.Result( + value=result, + display_images=[ + result[:, :, :3], + visualize_inpaint_mask(result), + ], + ) 
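The inpaint and lama hunks above are the first callers of the new multi-preview contract: `Preprocessor.Result` now carries a `display_images` list instead of a single `display_image`, and a preprocessor may return a ready-made `Result` (the `cached_call` hunk in `supported_preprocessor.py` further down wraps bare return values for backward compatibility). A hedged sketch of that fallback logic, with stand-in types rather than the real classes:

```python
from typing import Any, List, NamedTuple

import numpy as np


class Result(NamedTuple):
    # Stand-in for Preprocessor.Result as shown in the diff.
    value: Any
    display_images: List[np.ndarray]


def normalize(raw: Any, input_image: np.ndarray, returns_image: bool) -> Result:
    # Mirrors the fallback in cached_call: a preprocessor that still returns
    # a bare array is wrapped into a single-preview Result, while one that
    # builds its own Result (e.g. lama returning RGB plus a mask preview)
    # passes through untouched.
    if isinstance(raw, Result):
        return raw
    return Result(value=raw, display_images=[raw if returns_image else input_image])


img = np.zeros((8, 8, 4), dtype=np.uint8)
legacy = normalize(img, img, returns_image=True)
modern = normalize(Result(value=img, display_images=[img[:, :, :3], img]), img, True)
assert len(legacy.display_images) == 1 and len(modern.display_images) == 2
```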
Preprocessor.add_supported_preprocessor(PreprocessorLamaInpaint()) diff --git a/scripts/preprocessor/legacy/legacy_preprocessors.py b/scripts/preprocessor/legacy/legacy_preprocessors.py index 902e6c9d0..7c5e1c873 100644 --- a/scripts/preprocessor/legacy/legacy_preprocessors.py +++ b/scripts/preprocessor/legacy/legacy_preprocessors.py @@ -93,7 +93,7 @@ def unload(self): def __call__( self, input_image, - resolution, + resolution=512, slider_1=None, slider_2=None, slider_3=None, diff --git a/scripts/preprocessor/model_free_preprocessors.py b/scripts/preprocessor/model_free_preprocessors.py index 6b333e359..54e6c1f6d 100644 --- a/scripts/preprocessor/model_free_preprocessors.py +++ b/scripts/preprocessor/model_free_preprocessors.py @@ -93,7 +93,7 @@ class PreprocessorBlurGaussian(Preprocessor): def __init__(self): super().__init__(name="blur_gaussian") self.slider_1 = PreprocessorParameter( - label="Sigma", minimum=64, maximum=2048, value=512 + label="Sigma", minimum=0.01, maximum=64.0, value=9.0 ) self.tags = ["Tile"] diff --git a/scripts/preprocessor/pulid.py b/scripts/preprocessor/pulid.py new file mode 100644 index 000000000..a46f91290 --- /dev/null +++ b/scripts/preprocessor/pulid.py @@ -0,0 +1,169 @@ +# https://github.com/ToTheBeginning/PuLID + +import torch +import cv2 +import numpy as np +from typing import Optional, List +from dataclasses import dataclass +from facexlib.parsing import init_parsing_model +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from ..supported_preprocessor import Preprocessor, PreprocessorParameter +from scripts.utils import npimg2tensor, tensor2npimg, resize_image_with_pad + + +def to_gray(img): + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +class PreprocessorFaceXLib(Preprocessor): + def __init__(self): + super().__init__(name="facexlib") + self.tags = [] + self.slider_resolution = PreprocessorParameter(visible=False) + self.model: Optional[FaceRestoreHelper] = None + + def load_model(self): + if self.model is None: + self.model = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=self.device, + ) + self.model.face_parse = init_parsing_model( + model_name="bisenet", device=self.device + ) + self.model.face_parse.to(device=self.device) + self.model.face_det.to(device=self.device) + return self.model + + def unload(self) -> bool: + """@Override""" + if self.model is not None: + self.model.face_parse.to(device="cpu") + self.model.face_det.to(device="cpu") + return True + return False + + def __call__( + self, + input_image, + resolution=512, + slider_1=None, + slider_2=None, + slider_3=None, + input_mask=None, + return_tensor=False, + **kwargs + ): + """ + @Override + Returns black and white face features image with background removed. 
+ """ + self.load_model() + self.model.clean_all() + input_image, _ = resize_image_with_pad(input_image, resolution) + # using facexlib to detect and align face + image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) + self.model.read_image(image_bgr) + self.model.get_face_landmarks_5(only_center_face=True) + self.model.align_warp_face() + if len(self.model.cropped_faces) == 0: + raise RuntimeError("facexlib align face fail") + align_face = self.model.cropped_faces[0] + align_face_rgb = cv2.cvtColor(align_face, cv2.COLOR_BGR2RGB) + input = npimg2tensor(align_face_rgb) + input = input.to(self.device) + parsing_out = self.model.face_parse( + normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + )[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) + # only keep the face features + face_features_image = torch.where(bg, white_image, to_gray(input)) + if return_tensor: + return face_features_image + else: + return tensor2npimg(face_features_image) + + +@dataclass +class PuLIDProjInput: + id_ante_embedding: torch.Tensor + id_cond_vit: torch.Tensor + id_vit_hidden: List[torch.Tensor] + + +class PreprocessorPuLID(Preprocessor): + """PuLID preprocessor.""" + + def __init__(self): + super().__init__(name="ip-adapter_pulid") + self.tags = ["IP-Adapter"] + self.slider_resolution = PreprocessorParameter(visible=False) + self.returns_image = False + self.preprocessors_deps = [ + "facexlib", + "instant_id_face_embedding", + "EVA02-CLIP-L-14-336", + ] + + def facexlib_detect(self, input_image: np.ndarray) -> torch.Tensor: + facexlib_preprocessor = Preprocessor.get_preprocessor("facexlib") + return facexlib_preprocessor(input_image, return_tensor=True) + + def insightface_antelopev2_detect(self, input_image: np.ndarray) -> torch.Tensor: + antelopev2_preprocessor = Preprocessor.get_preprocessor( + "instant_id_face_embedding" + ) + return antelopev2_preprocessor(input_image) + + def unload(self) -> bool: + unloaded = False + for p_name in self.preprocessors_deps: + p = Preprocessor.get_preprocessor(p_name) + if p is not None: + unloaded = unloaded or p.unload() + return unloaded + + def __call__( + self, + input_image, + resolution, + slider_1=None, + slider_2=None, + slider_3=None, + input_mask=None, + **kwargs + ) -> Preprocessor.Result: + id_ante_embedding = self.insightface_antelopev2_detect(input_image) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) + + face_features_image = self.facexlib_detect(input_image) + evaclip_preprocessor = Preprocessor.get_preprocessor("EVA02-CLIP-L-14-336") + assert ( + evaclip_preprocessor is not None + ), "EVA02-CLIP-L-14-336 preprocessor not found! 
Please install sd-webui-controlnet-evaclip" + r = evaclip_preprocessor(face_features_image) + + return Preprocessor.Result( + value=PuLIDProjInput( + id_ante_embedding=id_ante_embedding, + id_cond_vit=r.id_cond_vit, + id_vit_hidden=r.id_vit_hidden, + ), + display_images=[tensor2npimg(face_features_image)], + ) + + +Preprocessor.add_supported_preprocessor(PreprocessorFaceXLib()) +Preprocessor.add_supported_preprocessor(PreprocessorPuLID()) diff --git a/scripts/supported_preprocessor.py b/scripts/supported_preprocessor.py index caf5a6a78..473d6203c 100644 --- a/scripts/supported_preprocessor.py +++ b/scripts/supported_preprocessor.py @@ -4,7 +4,7 @@ import numpy as np import torch -from modules import shared +from modules import shared, devices from scripts.logging import logger from scripts.utils import ndarray_lru_cache @@ -101,6 +101,7 @@ class Preprocessor(ABC): use_soft_projection_in_hr_fix = False expand_mask_when_resize_and_fill = False model: Optional[torch.nn.Module] = None + device = devices.get_device_for("controlnet") all_processors: ClassVar[Dict[str, "Preprocessor"]] = {} all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {} @@ -183,18 +184,19 @@ def unload_unused(cls, active_processors: Set["Preprocessor"]): class Result(NamedTuple): value: Any - # The display image shown on UI. - display_image: np.ndarray - - def get_display_image(self, input_image: np.ndarray, result): - return result if self.returns_image else input_image + # The display images shown on UI. + display_images: List[np.ndarray] def cached_call(self, input_image, *args, **kwargs) -> "Preprocessor.Result": """The function exposed that also returns an image for display.""" result = self._cached_call(input_image, *args, **kwargs) - return Preprocessor.Result( - value=result, display_image=self.get_display_image(input_image, result) - ) + if isinstance(result, Preprocessor.Result): + return result + else: + return Preprocessor.Result( + value=result, + display_images=[result if self.returns_image else input_image], + ) @ndarray_lru_cache(max_size=CACHE_SIZE) def _cached_call(self, *args, **kwargs): diff --git a/scripts/utils.py b/scripts/utils.py index c26750f14..e660279a9 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -1,3 +1,4 @@ +from einops import rearrange import torch import os import functools @@ -105,8 +106,9 @@ def wrapper(*args, **kwargs): class TimeMeta(type): - """ Metaclass to record execution time on all methods of the - child class. """ + """Metaclass to record execution time on all methods of the + child class.""" + def __new__(cls, name, bases, attrs): for attr_name, attr_value in attrs.items(): if callable(attr_value): @@ -161,7 +163,9 @@ def read_image(img_path: str) -> str: return encoded_image -def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: +def read_image_dir( + img_dir: str, suffixes=(".png", ".jpg", ".jpeg", ".webp") +) -> List[str]: """Try read all images in given img_dir.""" images = [] for filename in os.listdir(img_dir): @@ -175,7 +179,7 @@ def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> def align_dim_latent(x: int) -> int: - """ Align the pixel dimension (w/h) to latent dimension. + """Align the pixel dimension (w/h) to latent dimension. 
Stable diffusion 1:8 ratio for latent/pixel, i.e., 1 latent unit == 8 pixel units."""
     return (x // 8) * 8
@@ -203,9 +207,34 @@ def resize_image_with_pad(img: np.ndarray, resolution: int):
     W_target = int(np.round(float(W_raw) * k))
     img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
     H_pad, W_pad = pad64(H_target), pad64(W_target)
-    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
+    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode="edge")

     def remove_pad(x):
         return safer_memory(x[:H_target, :W_target])

-    return safer_memory(img_padded), remove_pad
\ No newline at end of file
+    return safer_memory(img_padded), remove_pad
+
+
+def npimg2tensor(img: np.ndarray) -> torch.Tensor:
+    """Convert numpy img ([H, W, C]) to tensor ([1, C, H, W])"""
+    return rearrange(torch.from_numpy(img).float() / 255.0, "h w c -> 1 c h w")
+
+
+def tensor2npimg(t: torch.Tensor) -> np.ndarray:
+    """Convert tensor ([1, C, H, W]) to numpy RGB img ([H, W, C])"""
+    return (
+        (rearrange(t, "1 c h w -> h w c") * 255.0)
+        .to(dtype=torch.uint8)
+        .cpu()
+        .numpy()
+    )
+
+
+def visualize_inpaint_mask(img):
+    if img.ndim == 3 and img.shape[2] == 4:
+        result = img.copy()
+        mask = result[:, :, 3]
+        mask = 255 - mask // 2
+        result[:, :, 3] = mask
+        return np.ascontiguousarray(result.copy())
+    return img
diff --git a/tests/cn_script/batch_hijack_test.py b/tests/cn_script/batch_hijack_test.py
index 0f68fe5bc..b8c1cc444 100644
--- a/tests/cn_script/batch_hijack_test.py
+++ b/tests/cn_script/batch_hijack_test.py
@@ -1,3 +1,4 @@
+import numpy as np
 import unittest.mock
 import importlib
 from typing import Any
@@ -6,13 +7,18 @@
 from modules import processing, scripts, shared

-from scripts import controlnet, external_code, batch_hijack
+from internal_controlnet.external_code import ControlNetUnit
+from scripts import controlnet, batch_hijack

 batch_hijack.instance.undo_hijack()
 original_process_images_inner = processing.process_images_inner

+def create_unit(**kwargs) -> ControlNetUnit:
+    return ControlNetUnit(enabled=True, **kwargs)
+
+
 class TestBatchHijack(unittest.TestCase):
     @unittest.mock.patch('modules.script_callbacks.on_script_unloaded')
     def setUp(self, on_script_unloaded_mock):
@@ -58,9 +64,18 @@ def assert_get_cn_batches_works(self, batch_images_list):
         is_cn_batch, batches, output_dir, _ = batch_hijack.get_cn_batches(self.p)
         batch_hijack.instance.dispatch_callbacks(batch_hijack.instance.process_batch_callbacks, self.p, batches, output_dir)

-        batch_units = [unit for unit in self.p.script_args if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH]
+        batch_units = [
+            unit
+            for unit in self.p.script_args
+            if getattr(unit, 'input_mode', batch_hijack.InputMode.SIMPLE) == batch_hijack.InputMode.BATCH
+        ]
+        # Convert the iterator to a list so the repeated evaluation in the
+        # following checks does not exhaust it.
+ for unit in batch_units: + unit.batch_images = list(unit.batch_images) + if batch_units: - self.assertEqual(min(len(list(unit.batch_images)) for unit in batch_units), len(batches)) + self.assertEqual(min(len(unit.batch_images) for unit in batch_units), len(batches)) else: self.assertEqual(1, len(batches)) @@ -73,15 +88,15 @@ def test_get_cn_batches__empty(self): self.assertEqual(is_batch, False) def test_get_cn_batches__1_simple(self): - self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image())) + self.p.script_args.append(create_unit(image=get_dummy_image())) self.assert_get_cn_batches_works([ - [self.p.script_args[0].image], + [get_dummy_image()], ]) def test_get_cn_batches__2_simples(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), - external_code.ControlNetUnit(image=get_dummy_image(1)), + create_unit(image=get_dummy_image(0)), + create_unit(image=get_dummy_image(1)), ]) self.assert_get_cn_batches_works([ [get_dummy_image(0)], @@ -90,7 +105,7 @@ def test_get_cn_batches__2_simples(self): def test_get_cn_batches__1_batch(self): self.p.script_args.extend([ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -107,14 +122,14 @@ def test_get_cn_batches__1_batch(self): def test_get_cn_batches__2_batches(self): self.p.script_args.extend([ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), get_dummy_image(1), ], ), - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(2), @@ -135,8 +150,8 @@ def test_get_cn_batches__2_batches(self): def test_get_cn_batches__2_mixed(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), - controlnet.UiControlNetUnit( + create_unit(image=get_dummy_image(0)), + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(1), @@ -157,8 +172,8 @@ def test_get_cn_batches__2_mixed(self): def test_get_cn_batches__3_mixed(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), - controlnet.UiControlNetUnit( + create_unit(image=get_dummy_image(0)), + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(1), @@ -166,7 +181,7 @@ def test_get_cn_batches__3_mixed(self): get_dummy_image(3), ], ), - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(4), @@ -242,14 +257,14 @@ def test_process_images_no_units_forwards(self): def test_process_images__only_simple_units__forwards(self): self.p.script_args = [ - external_code.ControlNetUnit(image=get_dummy_image()), - external_code.ControlNetUnit(image=get_dummy_image()), + create_unit(image=get_dummy_image()), + create_unit(image=get_dummy_image()), ] self.assert_process_images_hijack_called(batch_count=0) def test_process_images__1_batch_1_unit__runs_1_batch(self): self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(), @@ -260,7 +275,7 @@ def test_process_images__1_batch_1_unit__runs_1_batch(self): def test_process_images__2_batches_1_unit__runs_2_batches(self): self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -273,7 +288,7 @@ def 
test_process_images__2_batches_1_unit__runs_2_batches(self): def test_process_images__8_batches_1_unit__runs_8_batches(self): batch_count = 8 self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(i) for i in range(batch_count)] ), @@ -282,11 +297,11 @@ def test_process_images__8_batches_1_unit__runs_8_batches(self): def test_process_images__1_batch_2_units__runs_1_batch(self): self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(0)] ), - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[get_dummy_image(1)] ), @@ -295,14 +310,14 @@ def test_process_images__1_batch_2_units__runs_1_batch(self): def test_process_images__2_batches_2_units__runs_2_batches(self): self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), get_dummy_image(1), ], ), - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(2), @@ -314,7 +329,7 @@ def test_process_images__2_batches_2_units__runs_2_batches(self): def test_process_images__3_batches_2_mixed_units__runs_3_batches(self): self.p.script_args = [ - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ get_dummy_image(0), @@ -322,7 +337,7 @@ def test_process_images__3_batches_2_mixed_units__runs_3_batches(self): get_dummy_image(2), ], ), - controlnet.UiControlNetUnit( + create_unit( input_mode=batch_hijack.InputMode.SIMPLE, image=get_dummy_image(3), ), diff --git a/tests/cn_script/cn_script_test.py b/tests/cn_script/cn_script_test.py index 47a9ae2a2..1ff7d526d 100644 --- a/tests/cn_script/cn_script_test.py +++ b/tests/cn_script/cn_script_test.py @@ -7,8 +7,9 @@ utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils") -from scripts import external_code +from scripts.enums import ResizeMode from scripts.controlnet import prepare_mask, Script, set_numpy_seed +from internal_controlnet.external_code import ControlNetUnit from modules import processing @@ -117,16 +118,14 @@ class TestScript(unittest.TestCase): "AAAAAAAAAAAAAAAAAAAAAAAA/wZOlAAB5tU+nAAAAABJRU5ErkJggg==" ) - sample_np_image = np.array( - [[100, 200, 50], [150, 75, 225], [30, 120, 180]], dtype=np.uint8 - ) + sample_np_image = np.zeros(shape=[8, 8, 3], dtype=np.uint8) def test_choose_input_image(self): with self.subTest(name="no image"): with self.assertRaises(ValueError): Script.choose_input_image( p=processing.StableDiffusionProcessing(), - unit=external_code.ControlNetUnit(), + unit=ControlNetUnit(), idx=0, ) @@ -134,30 +133,30 @@ def test_choose_input_image(self): _, resize_mode = Script.choose_input_image( p=MockImg2ImgProcessing( init_images=[TestScript.sample_np_image], - resize_mode=external_code.ResizeMode.OUTER_FIT, + resize_mode=ResizeMode.OUTER_FIT, ), - unit=external_code.ControlNetUnit( - image=TestScript.sample_base64_image, + unit=ControlNetUnit( + image=TestScript.sample_np_image, module="none", - resize_mode=external_code.ResizeMode.INNER_FIT, + resize_mode=ResizeMode.INNER_FIT, ), idx=0, ) - self.assertEqual(resize_mode, external_code.ResizeMode.INNER_FIT) + self.assertEqual(resize_mode, ResizeMode.INNER_FIT) with self.subTest(name="A1111 input"): _, resize_mode = Script.choose_input_image( p=MockImg2ImgProcessing( 
init_images=[TestScript.sample_np_image], - resize_mode=external_code.ResizeMode.OUTER_FIT, + resize_mode=ResizeMode.OUTER_FIT, ), - unit=external_code.ControlNetUnit( + unit=ControlNetUnit( module="none", - resize_mode=external_code.ResizeMode.INNER_FIT, + resize_mode=ResizeMode.INNER_FIT, ), idx=0, ) - self.assertEqual(resize_mode, external_code.ResizeMode.OUTER_FIT) + self.assertEqual(resize_mode, ResizeMode.OUTER_FIT) if __name__ == "__main__": diff --git a/tests/cn_script/infotext_test.py b/tests/cn_script/infotext_test.py deleted file mode 100644 index 61a7002ee..000000000 --- a/tests/cn_script/infotext_test.py +++ /dev/null @@ -1,34 +0,0 @@ -import unittest -import importlib - -utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils") - -from scripts.infotext import parse_unit -from scripts.external_code import ControlNetUnit - - -class TestInfotext(unittest.TestCase): - def test_parsing(self): - infotext = ( - "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, " - "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, " - "Pixel Perfect: True, Control Mode: Balanced, Hr Option: Both, Save Detected Map: True" - ) - self.assertEqual( - vars( - ControlNetUnit( - module="inpaint_only+lama", - model="control_v11p_sd15_inpaint [ebff9138]", - weight=1, - resize_mode="Resize and Fill", - low_vram=False, - guidance_start=0, - guidance_end=1, - pixel_perfect=True, - control_mode="Balanced", - hr_option="Both", - save_detected_map=True, - ) - ), - vars(parse_unit(infotext)), - ) diff --git a/tests/external_code_api/external_code_test.py b/tests/external_code_api/external_code_test.py index b2b4101d2..e7c513088 100644 --- a/tests/external_code_api/external_code_test.py +++ b/tests/external_code_api/external_code_test.py @@ -9,6 +9,8 @@ from copy import copy from scripts import external_code from scripts import controlnet +from scripts.enums import ResizeMode +from internal_controlnet.external_code import ControlNetUnit from modules import scripts, ui, shared @@ -48,79 +50,15 @@ def test_empty_resizes_min_args(self): def test_empty_resizes_extra_args(self): extra_models = 1 - self.cn_units = [external_code.ControlNetUnit()] * (self.max_models + extra_models) + self.cn_units = [ControlNetUnit()] * (self.max_models + extra_models) self.assert_update_in_place_ok() -class TestControlNetUnitConversion(unittest.TestCase): - def setUp(self): - self.dummy_image = 'base64...' 
-        self.input = {}
-        self.expected = external_code.ControlNetUnit()
-
-    def assert_converts_to_expected(self):
-        self.assertEqual(vars(external_code.to_processing_unit(self.input)), vars(self.expected))
-
-    def test_empty_dict_works(self):
-        self.assert_converts_to_expected()
-
-    def test_image_works(self):
-        self.input = {
-            'image': self.dummy_image
-        }
-        self.expected = external_code.ControlNetUnit(image=self.dummy_image)
-        self.assert_converts_to_expected()
-
-    def test_image_alias_works(self):
-        self.input = {
-            'input_image': self.dummy_image
-        }
-        self.expected = external_code.ControlNetUnit(image=self.dummy_image)
-        self.assert_converts_to_expected()
-
-    def test_masked_image_works(self):
-        self.input = {
-            'image': self.dummy_image,
-            'mask': self.dummy_image,
-        }
-        self.expected = external_code.ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image})
-        self.assert_converts_to_expected()
-
-
-class TestControlNetUnitImageToDict(unittest.TestCase):
-    def setUp(self):
-        self.dummy_image = utils.readImage("test/test_files/img2img_basic.png")
-        self.input = external_code.ControlNetUnit()
-        self.expected_image = external_code.to_base64_nparray(self.dummy_image)
-        self.expected_mask = external_code.to_base64_nparray(self.dummy_image)
-
-    def assert_dict_is_valid(self):
-        actual_dict = controlnet.image_dict_from_any(self.input.image)
-        self.assertEqual(actual_dict['image'].tolist(), self.expected_image.tolist())
-        self.assertEqual(actual_dict['mask'].tolist(), self.expected_mask.tolist())
-
-    def test_none(self):
-        self.assertEqual(controlnet.image_dict_from_any(self.input.image), None)
-
-    def test_image_without_mask(self):
-        self.input.image = self.dummy_image
-        self.expected_mask = np.zeros_like(self.expected_image, dtype=np.uint8)
-        self.assert_dict_is_valid()
-
-    def test_masked_image_tuple(self):
-        self.input.image = (self.dummy_image, self.dummy_image,)
-        self.assert_dict_is_valid()
-
-    def test_masked_image_dict(self):
-        self.input.image = {'image': self.dummy_image, 'mask': self.dummy_image}
-        self.assert_dict_is_valid()
-
-
 class TestPixelPerfectResolution(unittest.TestCase):
     def test_outer_fit(self):
         image = np.zeros((100, 100, 3))
         target_H, target_W = 50, 100
-        resize_mode = external_code.ResizeMode.OUTER_FIT
+        resize_mode = ResizeMode.OUTER_FIT
         result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode)
         expected = 50  # manually computed expected result
         self.assertEqual(result, expected)
@@ -128,43 +66,11 @@ def test_outer_fit(self):
     def test_inner_fit(self):
         image = np.zeros((100, 100, 3))
         target_H, target_W = 50, 100
-        resize_mode = external_code.ResizeMode.INNER_FIT
+        resize_mode = ResizeMode.INNER_FIT
         result = external_code.pixel_perfect_resolution(image, target_H, target_W, resize_mode)
         expected = 100  # manually computed expected result
         self.assertEqual(result, expected)
 
 
-class TestGetAllUnitsFrom(unittest.TestCase):
-    def test_none(self):
-        self.assertListEqual(external_code.get_all_units_from([None]), [])
-
-    def test_bool(self):
-        self.assertListEqual(external_code.get_all_units_from([True]), [])
-
-    def test_inheritance(self):
-        class Foo(external_code.ControlNetUnit):
-            def __init__(self):
-                super().__init__(self)
-                self.bar = 'a'
-
-        foo = Foo()
-        self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
-
-    def test_dict(self):
-        units = external_code.get_all_units_from([{}])
-        self.assertGreater(len(units), 0)
-        self.assertIsInstance(units[0], external_code.ControlNetUnit)
-
-    def test_unitlike(self):
-        class Foo(object):
-            """ bar """
-
-        foo = Foo()
-        for key in vars(external_code.ControlNetUnit()).keys():
-            setattr(foo, key, True)
-        setattr(foo, 'bar', False)
-        self.assertListEqual(external_code.get_all_units_from([foo]), [foo])
-
-
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/tests/external_code_api/importlib_reload_test.py b/tests/external_code_api/importlib_reload_test.py
deleted file mode 100644
index 68ba7fb89..000000000
--- a/tests/external_code_api/importlib_reload_test.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import unittest
-import importlib
-utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
-
-
-from scripts import external_code
-
-
-class TestImportlibReload(unittest.TestCase):
-    def setUp(self):
-        self.ControlNetUnit = external_code.ControlNetUnit
-
-    def test_reload_does_not_redefine(self):
-        importlib.reload(external_code)
-        NewControlNetUnit = external_code.ControlNetUnit
-        self.assertEqual(self.ControlNetUnit, NewControlNetUnit)
-
-    def test_force_import_does_not_redefine(self):
-        external_code_copy = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')
-        self.assertEqual(self.ControlNetUnit, external_code_copy.ControlNetUnit)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/external_code_api/script_args_test.py b/tests/external_code_api/script_args_test.py
deleted file mode 100644
index 99c710260..000000000
--- a/tests/external_code_api/script_args_test.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import unittest
-import importlib
-utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils')
-
-
-from scripts import external_code
-
-
-class TestGetAllUnitsFrom(unittest.TestCase):
-    def setUp(self):
-        self.control_unit = {
-            "module": "none",
-            "model": utils.get_model("canny"),
-            "image": utils.readImage("test/test_files/img2img_basic.png"),
-            "resize_mode": 1,
-            "low_vram": False,
-            "processor_res": 64,
-            "control_mode": external_code.ControlMode.BALANCED.value,
-        }
-        self.object_unit = external_code.ControlNetUnit(**self.control_unit)
-
-    def test_empty_converts(self):
-        script_args = []
-        units = external_code.get_all_units_from(script_args)
-        self.assertListEqual(units, [])
-
-    def test_object_forwards(self):
-        script_args = [self.object_unit]
-        units = external_code.get_all_units_from(script_args)
-        self.assertListEqual(units, [self.object_unit])
-
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
diff --git a/tests/web_api/detect_test.py b/tests/web_api/detect_test.py
index dc37f33de..f4021aca3 100644
--- a/tests/web_api/detect_test.py
+++ b/tests/web_api/detect_test.py
@@ -75,7 +75,7 @@ def detect_template(payload, output_name: str, status: int = 200):
 def test_detect_all_modules(module: str):
     payload = dict(
         controlnet_input_images=[realistic_girl_face_img],
-        controlnet_masks=[mask_img],
+        controlnet_module=module,
     )
     detect_template(payload, f"detect_{module}")
 
@@ -143,16 +143,16 @@ def test_detect_default_param():
         dict(
             controlnet_input_images=[realistic_girl_face_img],
             controlnet_module="canny",  # Canny does not require model download.
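+            # Each value below is far outside its slider range, so the bound
+            # check should fall back to canny's defaults (512 / 100 / 200).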
-            controlnet_threshold_a=-1,
-            controlnet_threshold_b=-1,
-            controlnet_processor_res=-1,
+            controlnet_threshold_a=-100,
+            controlnet_threshold_b=-100,
+            controlnet_processor_res=-100,
         ),
         "default_param",
     )
     assert log_context.is_in_console_logs(
         [
-            "[canny.processor_res] Invalid value(-1), using default value 512.",
-            "[canny.threshold_a] Invalid value(-1.0), using default value 100.",
-            "[canny.threshold_b] Invalid value(-1.0), using default value 200.",
+            "[canny.processor_res] Invalid value(-100), using default value 512.",
+            "[canny.threshold_a] Invalid value(-100.0), using default value 100.",
+            "[canny.threshold_b] Invalid value(-100.0), using default value 200.",
         ]
     )
diff --git a/tests/web_api/full_coverage/template.py b/tests/web_api/full_coverage/template.py
index ecccc74dc..333c1d3f7 100644
--- a/tests/web_api/full_coverage/template.py
+++ b/tests/web_api/full_coverage/template.py
@@ -169,16 +169,16 @@ def expect_same_image(img1, img2, diff_img_path: str) -> bool:
 
 
 default_unit = {
-    "control_mode": 0,
+    "control_mode": "Balanced",
     "enabled": True,
     "guidance_end": 1,
     "guidance_start": 0,
     "low_vram": False,
     "pixel_perfect": True,
     "processor_res": 512,
-    "resize_mode": 1,
-    "threshold_a": 64,
-    "threshold_b": 64,
+    "resize_mode": "Crop and Resize",
+    "threshold_a": -1,
+    "threshold_b": -1,
     "weight": 1,
 }
diff --git a/tests/web_api/generation_test.py b/tests/web_api/generation_test.py
index 0dd39357d..8da3d1c4e 100644
--- a/tests/web_api/generation_test.py
+++ b/tests/web_api/generation_test.py
@@ -87,12 +87,13 @@ def test_invalid_param(gen_type, param_name):
         f"test_invalid_param{(gen_type, param_name)}",
         gen_type,
         payload_overrides={},
-        unit_overrides={param_name: -1},
+        unit_overrides={param_name: -100},
         input_image=girl_img,
     ).exec()
+    number = "-100" if param_name == "processor_res" else "-100.0"
     assert log_context.is_in_console_logs(
         [
-            f"[canny.{param_name}] Invalid value(-1), using default value",
+            f"[canny.{param_name}] Invalid value({number}), using default value",
         ]
     )
 
@@ -171,7 +172,7 @@ def test_reference():
             "model": "None",
         },
         input_image=girl_img,
-    ).exec()
+    ).exec(result_only=False)
 
 
 def test_advanced_weighting():
@@ -192,7 +193,7 @@ def test_hr_option():
             "enable_hr": True,
             "denoising_strength": 0.75,
         },
-        unit_overrides={"hr_option": "HiResFixOption.BOTH"},
+        unit_overrides={"hr_option": "Both"},
         input_image=girl_img,
     ).exec(expected_output_num=3)
 
@@ -203,7 +204,7 @@ def test_hr_option_default():
         "test_hr_option_default",
         "txt2img",
         payload_overrides={"enable_hr": False},
-        unit_overrides={"hr_option": "HiResFixOption.BOTH"},
+        unit_overrides={"hr_option": "Both"},
         input_image=girl_img,
     ).exec(expected_output_num=2)
 
diff --git a/tests/web_api/template.py b/tests/web_api/template.py
index 2f3a56344..aa8423dab 100644
--- a/tests/web_api/template.py
+++ b/tests/web_api/template.py
@@ -295,15 +295,15 @@ def get_model(model_name: str) -> str:
 
 
 default_unit = {
-    "control_mode": 0,
+    "control_mode": "Balanced",
     "enabled": True,
     "guidance_end": 1,
     "guidance_start": 0,
     "pixel_perfect": True,
     "processor_res": 512,
-    "resize_mode": 1,
-    "threshold_a": 64,
-    "threshold_b": 64,
+    "resize_mode": "Crop and Resize",
+    "threshold_a": -1,
+    "threshold_b": -1,
     "weight": 1,
     "module": "canny",
     "model": get_model("sd15_canny"),
diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/unit_tests/args_test.py b/unit_tests/args_test.py
new file mode 100644
index 000000000..de9158220
--- /dev/null
+++ b/unit_tests/args_test.py
@@ -0,0 +1,241 @@
+import pytest
+import torch
+import numpy as np
+from dataclasses import dataclass
+
+from internal_controlnet.args import ControlNetUnit
+
+H = W = 128
+
+img1 = np.ones(shape=[H, W, 3], dtype=np.uint8)
+img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2
+mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2
+mask_2d = np.ones(shape=[H, W])
+img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2
+img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2
+ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2
+ui_img = np.ones(shape=[H, W, 4], dtype=np.uint8)
+tensor1 = torch.zeros(size=[1, 1], dtype=torch.float16)
+
+
+@pytest.fixture(scope="module")
+def set_cls_funcs():
+    ControlNetUnit.cls_match_model = lambda s: s in {
+        "None",
+        "model1",
+        "model2",
+        "control_v11p_sd15_inpaint [ebff9138]",
+    }
+    ControlNetUnit.cls_match_module = lambda s: s in {
+        "none",
+        "module1",
+        "inpaint_only+lama",
+    }
+    ControlNetUnit.cls_decode_base64 = lambda s: {
+        "b64img1": img1,
+        "b64img2": img2,
+        "b64mask_diff": mask_diff,
+    }[s]
+    ControlNetUnit.cls_torch_load_base64 = lambda s: {
+        "b64tensor1": tensor1,
+    }[s]
+    ControlNetUnit.cls_get_preprocessor = lambda s: {
+        "module1": MockPreprocessor(),
+        "none": MockPreprocessor(),
+        "inpaint_only+lama": MockPreprocessor(),
+    }[s]
+
+
+def test_module_invalid(set_cls_funcs):
+    with pytest.raises(ValueError) as excinfo:
+        ControlNetUnit(module="foo")
+
+    assert "module(foo) not found in supported modules." in str(excinfo.value)
+
+
+def test_module_valid(set_cls_funcs):
+    ControlNetUnit(module="module1")
+
+
+def test_model_invalid(set_cls_funcs):
+    with pytest.raises(ValueError) as excinfo:
+        ControlNetUnit(model="foo")
+
+    assert "model(foo) not found in supported models." in str(excinfo.value)
+
+
+def test_model_valid(set_cls_funcs):
+    ControlNetUnit(model="model1")
+
+
+@pytest.mark.parametrize(
+    "d",
+    [
+        # API
+        dict(image={"image": "b64img1"}),
+        dict(image={"image": "b64img1", "mask": "b64img2"}),
+        dict(image=["b64img1", "b64img2"]),
+        dict(image=("b64img1", "b64img2")),
+        dict(image=[{"image": "b64img1", "mask": "b64img2"}]),
+        dict(image=[{"image": "b64img1"}]),
+        dict(image=[{"image": "b64img1", "mask": None}]),
+        dict(
+            image=[
+                {"image": "b64img1", "mask": "b64img2"},
+                {"image": "b64img1", "mask": "b64img2"},
+            ]
+        ),
+        dict(
+            image=[
+                {"image": "b64img1", "mask": None},
+                {"image": "b64img1", "mask": "b64img2"},
+            ]
+        ),
+        dict(
+            image=[
+                {"image": "b64img1"},
+                {"image": "b64img1", "mask": "b64img2"},
+            ]
+        ),
+        dict(image="b64img1", mask="b64img2"),
+        dict(image="b64img1"),
+        dict(image="b64img1", mask_image="b64img2"),
+        dict(image=None),
+        # UI
+        dict(image=dict(image=img1)),
+        dict(image=dict(image=img1, mask=img2)),
+        # 2D mask should be accepted.
+        dict(image=dict(image=img1, mask=mask_2d)),
+        dict(image=img1, mask=mask_2d),
+    ],
+)
+def test_valid_image_formats(set_cls_funcs, d):
+    ControlNetUnit(**d)
+    unit = ControlNetUnit.from_dict(d)
+    unit.get_input_images_rgba()
+
+
+@pytest.mark.parametrize(
+    "d",
+    [
+        dict(image={"mask": "b64img1"}),
+        dict(image={"foo": "b64img1", "bar": "b64img2"}),
+        dict(image=["b64img1"]),
+        dict(image=("b64img1", "b64img2", "b64img1")),
+        dict(image=[]),
+        dict(image=[{"mask": "b64img1"}]),
+        dict(image=None, mask="b64img2"),
+        # image & mask have different H x W
+        dict(image="b64img1", mask="b64mask_diff"),
+    ],
+)
+def test_invalid_image_formats(set_cls_funcs, d):
+    # Setting field will be fine.
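+    # Construction and from_dict accept the malformed payload as-is;
+    # only the get_input_images_rgba() call below is expected to reject it.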
+    ControlNetUnit(**d)
+    unit = ControlNetUnit.from_dict(d)
+    # Error on eval.
+    with pytest.raises((ValueError, AssertionError)):
+        unit.get_input_images_rgba()
+
+
+def test_mask_alias_conflict():
+    with pytest.raises((ValueError, AssertionError)):
+        ControlNetUnit.from_dict(
+            dict(
+                image="b64img1",
+                mask="b64img1",
+                mask_image="b64img1",
+            )
+        )
+
+
+def test_resize_mode():
+    ControlNetUnit(resize_mode="Just Resize")
+
+
+def test_weight():
+    ControlNetUnit(weight=0.5)
+    ControlNetUnit(weight=0.0)
+    with pytest.raises(ValueError):
+        ControlNetUnit(weight=-1)
+    with pytest.raises(ValueError):
+        ControlNetUnit(weight=100)
+
+
+def test_start_end():
+    ControlNetUnit(guidance_start=0.0, guidance_end=1.0)
+    ControlNetUnit(guidance_start=0.5, guidance_end=1.0)
+    ControlNetUnit(guidance_start=0.5, guidance_end=0.5)
+
+    with pytest.raises(ValueError):
+        ControlNetUnit(guidance_start=1.0, guidance_end=0.0)
+    with pytest.raises(ValueError):
+        ControlNetUnit(guidance_start=11)
+    with pytest.raises(ValueError):
+        ControlNetUnit(guidance_end=11)
+
+
+def test_effective_region_mask():
+    ControlNetUnit(effective_region_mask="b64img1")
+    ControlNetUnit(effective_region_mask=None)
+    ControlNetUnit(effective_region_mask=img1)
+
+    with pytest.raises(ValueError):
+        ControlNetUnit(effective_region_mask=124)
+
+
+def test_ipadapter_input():
+    ControlNetUnit(ipadapter_input=["b64tensor1"])
+    ControlNetUnit(ipadapter_input="b64tensor1")
+    ControlNetUnit(ipadapter_input=None)
+
+    with pytest.raises(ValueError):
+        ControlNetUnit(ipadapter_input=[])
+
+
+@dataclass
+class MockSlider:
+    value: float = 1
+    minimum: float = 0
+    maximum: float = 2
+
+
+@dataclass
+class MockPreprocessor:
+    slider_resolution = MockSlider()
+    slider_1 = MockSlider()
+    slider_2 = MockSlider()
+
+
+def test_preprocessor_sliders():
+    unit = ControlNetUnit(enabled=True, module="none")
+    assert unit.processor_res == 1
+    assert unit.threshold_a == 1
+    assert unit.threshold_b == 1
+
+
+def test_preprocessor_sliders_disabled():
+    unit = ControlNetUnit(enabled=False, module="none")
+    assert unit.processor_res == -1
+    assert unit.threshold_a == -1
+    assert unit.threshold_b == -1
+
+
+def test_infotext_parsing():
+    infotext = (
+        "Module: inpaint_only+lama, Model: control_v11p_sd15_inpaint [ebff9138], Weight: 1, "
+        "Resize Mode: Resize and Fill, Low Vram: False, Guidance Start: 0, Guidance End: 1, "
+        "Pixel Perfect: True, Control Mode: Balanced"
+    )
+    assert ControlNetUnit(
+        enabled=True,
+        module="inpaint_only+lama",
+        model="control_v11p_sd15_inpaint [ebff9138]",
+        weight=1,
+        resize_mode="Resize and Fill",
+        low_vram=False,
+        guidance_start=0,
+        guidance_end=1,
+        pixel_perfect=True,
+        control_mode="Balanced",
+    ) == ControlNetUnit.parse(infotext)
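+
+
+def test_preprocessor_sliders_out_of_range():
+    # Extra sketch, not exhaustive: assumes out-of-range slider values are
+    # clamped back to the preprocessor default by the unit's bound check.
+    # MockSlider above accepts [0, 2] with default 1, so 100 should become 1.
+    unit = ControlNetUnit(enabled=True, module="none", processor_res=100)
+    assert unit.processor_res == 1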