Skip to content

Commit

Permalink
Merge pull request #469 from owenrao/main
Browse files Browse the repository at this point in the history
build: add pyracanny edge preprocessor just like canny
  • Loading branch information
Fannovel16 authored Oct 17, 2024
2 parents 302a389 + c44234e commit f5868ff
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
30 changes: 30 additions & 0 deletions node_wrappers/pyracanny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from ..utils import common_annotator_call, INPUT, define_preprocessor_inputs
import comfy.model_management as model_management

class PyraCanny_Preprocessor:
    """Node wrapper that runs ``PyraCannyDetector`` on an input image.

    The class-level attributes declare the node's single ``IMAGE`` output,
    the name of the method to invoke (``execute``) and the menu category.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Line Extractors"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: two Canny thresholds and a resolution."""
        inputs = dict(
            low_threshold=INPUT.INT(default=64, max=255),
            high_threshold=INPUT.INT(default=128, max=255),
            resolution=INPUT.RESOLUTION(),
        )
        return define_preprocessor_inputs(**inputs)

    def execute(self, image, low_threshold=64, high_threshold=128, resolution=512, **kwargs):
        """Run the pyramid Canny detector on ``image``.

        Extra keyword arguments are accepted and ignored. Returns a 1-tuple
        holding whatever ``common_annotator_call`` produces.
        """
        # Imported lazily, at call time rather than module import time.
        from custom_controlnet_aux.pyracanny import PyraCannyDetector

        detector = PyraCannyDetector()
        annotated = common_annotator_call(
            detector,
            image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            resolution=resolution,
        )
        return (annotated,)



# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "PyraCannyPreprocessor": PyraCanny_Preprocessor,
}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PyraCannyPreprocessor": "PyraCanny",
}
74 changes: 74 additions & 0 deletions src/custom_controlnet_aux/pyracanny/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3

def centered_canny(x: np.ndarray, canny_low_threshold, canny_high_threshold):
    """Run ``cv2.Canny`` on a single 8-bit channel and return it as float32.

    ``x`` must be a 2-D ``uint8`` array. Both thresholds are converted with
    ``int()`` before being handed to OpenCV, and the detector output is
    divided by 255.
    """
    assert isinstance(x, np.ndarray)
    assert x.ndim == 2 and x.dtype == np.uint8

    low, high = int(canny_low_threshold), int(canny_high_threshold)
    edges = cv2.Canny(x, low, high)
    return edges.astype(np.float32) / 255.0

def centered_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
    """Apply ``centered_canny`` to each of the three channels of ``x``.

    ``x`` must be a 3-D array whose last axis has length 3. The per-channel
    edge maps are stacked back along axis 2, so the result has the same
    height, width and channel count as the input.
    """
    assert isinstance(x, np.ndarray)
    assert x.ndim == 3 and x.shape[2] == 3

    per_channel = []
    for channel in range(3):
        per_channel.append(
            centered_canny(x[..., channel], canny_low_threshold, canny_high_threshold)
        )
    return np.stack(per_channel, axis=2)

def pyramid_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
    """Accumulate per-channel Canny edges over a pyramid of image scales.

    The image is resized to 20%, 30%, ..., 100% of its size. At every scale
    the per-channel edge map is computed with ``centered_canny_color``; the
    running accumulator is resized to that scale and blended with the new
    map using weights 0.75 (accumulator) and 0.25 (new edges).

    ``x`` must be a 3-D array whose last axis has length 3. The returned
    array has the spatial size of the final (100%) scale.
    """
    assert isinstance(x, np.ndarray)
    assert x.ndim == 3 and x.shape[2] == 3

    H, W = x.shape[:2]
    acc_edge = None

    for k in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        # int() truncation yields 0 for very small inputs (e.g. 4 px * 0.2),
        # which would hand cv2.resize an empty target size; keep >= 1 pixel.
        Hs, Ws = max(1, int(H * k)), max(1, int(W * k))
        small = cv2.resize(x, (Ws, Hs), interpolation=cv2.INTER_AREA)
        edge = centered_canny_color(small, canny_low_threshold, canny_high_threshold)
        if acc_edge is None:
            acc_edge = edge
        else:
            # Bring the accumulator up to the current scale before blending.
            acc_edge = cv2.resize(acc_edge, (edge.shape[1], edge.shape[0]), interpolation=cv2.INTER_LINEAR)
            acc_edge = acc_edge * 0.75 + edge * 0.25

    return acc_edge

def norm255(x, low=4, high=96):
    """Linearly rescale ``x`` so its percentile range maps onto 0..255.

    The ``low``-th percentile of ``x`` is mapped to 0 and the ``high``-th
    percentile to 255; values outside that range fall outside 0..255 and are
    NOT clipped here.

    Args:
        x: 2-D ``float32`` array. It is left unmodified.
        low: Lower percentile (0-100) mapped to 0.
        high: Upper percentile (0-100) mapped to 255.

    Returns:
        A new ``float32`` array with the same shape as ``x``. If both
        percentiles coincide, the full min/max range of ``x`` is used
        instead; if ``x`` is constant, an all-zero array is returned.
    """
    assert isinstance(x, np.ndarray)
    assert x.ndim == 2 and x.dtype == np.float32

    # Plain Python floats keep the arithmetic below in float32.
    v_min = float(np.percentile(x, low))
    v_max = float(np.percentile(x, high))

    if v_max == v_min:
        # Degenerate percentile range (e.g. an almost empty edge map):
        # dividing by zero would produce NaN/inf, so widen to min/max.
        v_min, v_max = float(x.min()), float(x.max())
        if v_max == v_min:
            return np.zeros_like(x)

    # Out-of-place arithmetic so the caller's array is not mutated.
    return (x - v_min) / (v_max - v_min) * 255.0

def canny_pyramid(x, canny_low_threshold, canny_high_threshold):
    """Build a single-channel ``uint8`` edge map from a multi-scale Canny pyramid.

    The per-channel pyramid result is summed over the channel axis,
    rescaled with ``norm255`` using the 1st and 99th percentiles, clipped
    to 0..255 and converted to ``uint8``.
    """
    # SAI's Control-LoRA Canny appears to have been trained on Canny maps with
    # non-standard resolutions, so a pyramid over many resolutions is used to
    # avoid missing structure that only shows up at a specific scale.
    per_channel_edges = pyramid_canny_color(x, canny_low_threshold, canny_high_threshold)
    merged = per_channel_edges.sum(axis=2)
    scaled = norm255(merged, low=1, high=99)
    return scaled.clip(0, 255).astype(np.uint8)

class PyraCannyDetector:
    """Callable detector producing a pyramid-Canny edge map for an image."""

    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Detect edges in ``input_image``.

        The input is validated with ``common_input_validate``, resized and
        padded to ``detect_resolution``, run through ``canny_pyramid``, then
        un-padded and passed through ``HWC3``. When ``output_type`` is
        ``"pil"`` the result is wrapped in a ``PIL.Image``; otherwise the
        array is returned as-is.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        padded, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        edges = canny_pyramid(padded, low_threshold, high_threshold)
        result = HWC3(remove_pad(edges))

        return Image.fromarray(result) if output_type == "pil" else result

0 comments on commit f5868ff

Please sign in to comment.