From bdf27d8ee410fd273bc9b66c6cb00bc9bf5da8cc Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Tue, 24 Dec 2024 11:05:45 +0800 Subject: [PATCH 01/13] seperate SegmentAnythingUltra V2 into nodes --- segment_anything.py | 708 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 708 insertions(+) diff --git a/segment_anything.py b/segment_anything.py index 6e478ac7..310c4fbf 100644 --- a/segment_anything.py +++ b/segment_anything.py @@ -1,12 +1,29 @@ +import copy +import glob import hashlib import json +import math import os from enum import Enum +from pathlib import Path +from typing import List, Union +from urllib.parse import urlparse +import comfy.model_management +import cv2 import folder_paths +import groundingdino.datasets.transforms as T import numpy as np import torch +from groundingdino.models import build_model as local_groundingdino_build_model +from groundingdino.util.inference import predict +from groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig +from groundingdino.util.utils import ( + clean_state_dict as local_groundingdino_clean_state_dict, +) from PIL import Image, ImageOps, ImageSequence +from segment_anything_hq import SamPredictor, sam_model_registry +from torch.hub import download_url_to_file from bizyair.common.env_var import BIZYAIR_SERVER_ADDRESS from bizyair.image_utils import decode_base64_to_np, encode_image_to_base64 @@ -15,6 +32,53 @@ from .route_sam import SAM_COORDINATE from .utils import get_api_key, send_post_request +try: + from cv2.ximgproc import guidedFilter +except ImportError: + # print(e) + print( + f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" + f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." 
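        # guidedFilter lives in cv2.ximgproc, which is shipped by opencv-contrib-python;
        # the plain opencv-python wheel does not include it, hence this import guard.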
+ f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" + ) + +sam_model_dir_name = "sams" +sam_model_list = { + "sam_vit_h (2.56GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + }, + "sam_vit_l (1.25GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" + }, + "sam_vit_b (375MB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + }, + "sam_hq_vit_h (2.57GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" + }, + "sam_hq_vit_l (1.25GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" + }, + "sam_hq_vit_b (379MB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" + }, + "mobile_sam(39MB)": { + "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" + }, +} + +groundingdino_model_dir_name = "grounding-dino" +groundingdino_model_list = { + "GroundingDINO_SwinT_OGC (694MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", + }, + "GroundingDINO_SwinB (938MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", + }, +} + class INFER_MODE(Enum): auto = 0 @@ -28,6 +92,389 @@ class EDIT_MODE(Enum): point = 1 +def get_bert_base_uncased_model_path(): + comfy_bert_model_base = os.path.join(folder_paths.models_dir, "bert-base-uncased") + if glob.glob( + os.path.join(comfy_bert_model_base, "**/model.safetensors"), recursive=True + ): + print("grounding-dino is using models/bert-base-uncased") + return comfy_bert_model_base + return "bert-base-uncased" + + +def save_masks(outmasks, image): + if len(outmasks) == 0: + return + print("why masks: ", outmasks) + if len(outmasks.shape) > 3: + outmasks = outmasks.permute(1, 0, 2, 3) + outmasks = ( + outmasks.view(outmasks.shape[1], outmasks.shape[2], outmasks.shape[3]) + .cpu() + .numpy() + ) + + image_height, image_width, _ = image.shape + + img = np.zeros((image_height, image_width, 3), dtype=np.uint8) + mask_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) + + for mask in outmasks: + non_zero_positions = np.nonzero(mask) + # 使用非零值位置从原始图像中提取实例部分 + instance_part = image[non_zero_positions] + # 将提取的实例部分应用到合并图像中对应的位置 + img[non_zero_positions] = instance_part + + mask_image[non_zero_positions] = 255 # 标记为白色 + print("why ahahahah") + return img, mask_image + + +def list_sam_model(): + return list(sam_model_list.keys()) + + +def load_sam_model(model_name): + sam_checkpoint_path = get_local_filepath( + sam_model_list[model_name]["model_url"], sam_model_dir_name + ) + model_file_name = os.path.basename(sam_checkpoint_path) + model_type = model_file_name.split(".")[0] + if "hq" not in model_type and "mobile" not in model_type: + model_type = "_".join(model_type.split("_")[1:-1]) + print("why model_type: ", model_type) + print("path:", sam_checkpoint_path) + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path) + sam_device = comfy.model_management.get_torch_device() + sam.to(device=sam_device) + sam.eval() + sam.model_name = model_file_name + predictor = SamPredictor(sam) + return (sam, 
predictor) + + +def sam_predict( + predictor, image, input_points, input_label, input_boxes, multimask_output +): + predictor.set_image(image) + masks, scores, logits = predictor.predict( + point_coords=input_points, + point_labels=input_label, + box=input_boxes, + multimask_output=multimask_output, + ) + return masks, scores, logits + + +def sam_predict_torch( + predictor, image_np, input_points, input_boxes, input_label, multimask_output +): + + image_np_rgb = image_np[..., :3] + predictor.set_image(image_np_rgb) + + transformed_boxes = predictor.transform.apply_boxes_torch( + input_boxes, image_np.shape[:2] + ) + + masks, scores, logits = predictor.predict_torch( + point_coords=input_points, + point_labels=input_label, + boxes=transformed_boxes, + multimask_output=multimask_output, + ) + + return masks, scores, logits + + +def get_local_filepath(url, dirname, local_file_name=None): + if not local_file_name: + parsed_url = urlparse(url) + local_file_name = os.path.basename(parsed_url.path) + + destination = folder_paths.get_full_path(dirname, local_file_name) + if destination: + return destination + + folder = os.path.join(folder_paths.models_dir, dirname) + if not os.path.exists(folder): + os.makedirs(folder) + + destination = os.path.join(folder, local_file_name) + if not os.path.exists(destination): + download_url_to_file(url, destination) + return destination + + +def load_groundingdino_model(model_name): + dino_model_args = local_groundingdino_SLConfig.fromfile( + get_local_filepath( + groundingdino_model_list[model_name]["config_url"], + groundingdino_model_dir_name, + ), + ) + + if dino_model_args.text_encoder_type == "bert-base-uncased": + dino_model_args.text_encoder_type = get_bert_base_uncased_model_path() + + dino = local_groundingdino_build_model(dino_model_args) + checkpoint = torch.load( + get_local_filepath( + groundingdino_model_list[model_name]["model_url"], + groundingdino_model_dir_name, + ), + ) + dino.load_state_dict( + local_groundingdino_clean_state_dict(checkpoint["model"]), strict=False + ) + device = comfy.model_management.get_torch_device() + dino.to(device=device) + dino.eval() + return dino + + +def list_groundingdino_model(): + return list(groundingdino_model_list.keys()) + + +def load_image(image_pil): + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + image_transformed, _ = transform(image_pil, None) + return image_transformed + + +def guided_filter_alpha( + image: torch.Tensor, mask: torch.Tensor, filter_radius: int +) -> torch.Tensor: + sigma = 0.15 + d = filter_radius + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + s = sigma / 10 + i_dup = copy.deepcopy(image.cpu().numpy()) + a_dup = copy.deepcopy(mask.cpu().numpy()) + for index, image in enumerate(i_dup): + alpha_work = a_dup[index] + i_dup[index] = guidedFilter(image, alpha_work, d, s) + return torch.from_numpy(i_dup) + + +def histogram_remap( + image: torch.Tensor, blackpoint: float, whitepoint: float +) -> torch.Tensor: + bp = min(blackpoint, whitepoint - 0.001) + scale = 1 / (whitepoint - bp) + i_dup = copy.deepcopy(image.cpu().numpy()) + i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0) + return torch.from_numpy(i_dup) + + +def mask_edge_detail( + image: torch.Tensor, + mask: torch.Tensor, + detail_range: int = 8, + black_point: float = 0.01, + white_point: float = 0.99, +) -> torch.Tensor: + from pymatting import estimate_alpha_cf, fix_trimap + + d = 
detail_range * 5 + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64)) + a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64)) + for index, img in enumerate(i_dup): + trimap = a_dup[index][:, :, 0] # convert to single channel + if detail_range > 0: + trimap = cv2.GaussianBlur(trimap, (d, d), 0) + trimap = fix_trimap(trimap, black_point, white_point) + alpha = estimate_alpha_cf( + img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} + ) + a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb + return torch.from_numpy(a_dup.astype(np.float32)) + + +def generate_VITMatte_trimap( + mask: torch.Tensor, erode_kernel_size: int, dilate_kernel_size: int +) -> Image: + def g_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10): + erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8) + dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8) + eroded = cv2.erode(mask, erode_kernel, iterations=5) + dilated = cv2.dilate(mask, dilate_kernel, iterations=5) + trimap = np.zeros_like(mask) + trimap[dilated == 255] = 128 + trimap[eroded == 255] = 255 + return trimap + + mask = mask.squeeze(0).cpu().detach().numpy().astype(np.uint8) * 255 + trimap = g_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32) + trimap[trimap == 128] = 0.5 + trimap[trimap == 255] = 1 + trimap = torch.from_numpy(trimap).unsqueeze(0) + + return tensor2pil(trimap).convert("L") + + +def generate_VITMatte( + image: Image, + trimap: Image, + local_files_only: bool = False, + device: str = "cpu", + max_megapixels: float = 2.0, +) -> Image: + if image.mode != "RGB": + image = image.convert("RGB") + if trimap.mode != "L": + trimap = trimap.convert("L") + max_megapixels *= 1048576 + width, height = image.size + ratio = width / height + target_width = math.sqrt(ratio * max_megapixels) + target_height = target_width / ratio + target_width = int(target_width) + target_height = int(target_height) + if width * height > max_megapixels: + image = image.resize((target_width, target_height), Image.BILINEAR) + trimap = trimap.resize((target_width, target_height), Image.BILINEAR) + print( + f"vitmatte image size {width}x{height} too large, resize to {target_width}x{target_height} for processing." + ) + model_name = "hustvl/vitmatte-small-composition-1k" + if device == "cpu": + device = torch.device("cpu") + else: + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + print( + "vitmatte device is set to cuda, but not available, using cpu instead." + ) + device = torch.device("cpu") + vit_matte_model = load_VITMatte_model( + model_name=model_name, local_files_only=local_files_only + ) + vit_matte_model.model.to(device) + print( + f"vitmatte processing, image size = {image.width}x{image.height}, device = {device}." 
+ ) + inputs = vit_matte_model.processor( + images=image, trimaps=trimap, return_tensors="pt" + ) + with torch.no_grad(): + inputs = {k: v.to(device) for k, v in inputs.items()} + predictions = vit_matte_model.model(**inputs).alphas + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + mask = tensor2pil(predictions).convert("L") + mask = mask.crop( + (0, 0, image.width, image.height) + ) # remove padding that the prediction appends (works in 32px tiles) + if width * height > max_megapixels: + mask = mask.resize((width, height), Image.BILINEAR) + return mask + + +class VITMatteModel: + def __init__(self, model, processor): + self.model = model + self.processor = processor + + +def load_VITMatte_model(model_name: str, local_files_only: bool = False) -> object: + # if local_files_only: + # model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) + model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) + from transformers import VitMatteForImageMatting, VitMatteImageProcessor + + model = VitMatteForImageMatting.from_pretrained( + model_name, local_files_only=local_files_only + ) + processor = VitMatteImageProcessor.from_pretrained( + model_name, local_files_only=local_files_only + ) + vitmatte = VITMatteModel(model, processor) + return vitmatte + + +def pil2tensor(image: Image) -> torch.Tensor: + return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2pil(t_image: torch.Tensor) -> Image: + return Image.fromarray( + np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + ) + + +def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: + if len(tensor.shape) == 3: # Single image + return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) + else: # Batch of images + return [ + np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor + ] + + +def mask2image(mask: torch.Tensor) -> Image: + masks = tensor2np(mask) + for m in masks: + _mask = Image.fromarray(m).convert("L") + _image = Image.new("RGBA", _mask.size, color="white") + _image = Image.composite( + _image, Image.new("RGBA", _mask.size, color="black"), _mask + ) + return _image + + +def image2mask(image: Image) -> torch.Tensor: + _image = image.convert("RGBA") + alpha = _image.split()[0] + bg = Image.new("L", _image.size) + _image = Image.merge("RGBA", (bg, bg, bg, alpha)) + ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()]) + return ret_mask + + +def RGB2RGBA(image: Image, mask: Image) -> Image: + (R, G, B) = image.convert("RGB").split() + return Image.merge("RGBA", (R, G, B, mask.convert("L"))) + + +def groundingdino_predict(dino_model, image_pil, prompt, box_threshold, text_threshold): + image = load_image(image_pil) + + boxes, logits, phrases = predict( + model=dino_model, + image=image, + caption=prompt, + box_threshold=box_threshold, + text_threshold=text_threshold, + ) + + filt_mask = logits > box_threshold + boxes_filt = boxes.clone() + boxes_filt = boxes_filt[filt_mask] + H, W = image_pil.size[1], image_pil.size[0] + for i in range(boxes_filt.size(0)): + boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) + boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 + boxes_filt[i][2:] += boxes_filt[i][:2] + return boxes_filt + + class BizyAirSegmentAnythingText: API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam" @@ -267,11 +714,272 @@ def VALIDATE_INPUTS(s, image, is_point): return True +class BizyAirSAMModelLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": 
{ + "model_name": (list_sam_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ( + "SAM_MODEL", + "SAM_PREDICTOR", + ) + + def main(self, model_name): + sam_model, sam_predictor = load_sam_model(model_name) + return ( + sam_model, + sam_predictor, + ) + + +class BizyAirGroundingDinoModelLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": (list_groundingdino_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("GROUNDING_DINO_MODEL",) + + def main(self, model_name): + dino_model = load_groundingdino_model(model_name) + return (dino_model,) + + +class BizyAirGroundingDinoSAMSegment: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "sam_model": ("SAM_MODEL", {}), + "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}), + "sam_predictor": ("SAM_PREDICTOR", {}), + "image": ("IMAGE", {}), + "prompt": ("STRING", {}), + "box_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + "text_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("IMAGE", "MASK") + + def main( + self, + grounding_dino_model, + sam_model, + sam_predictor, + image, + prompt, + box_threshold, + text_threshold, + ): + res_images = [] + res_masks = [] + multimask_output = False + for item in image: + item = Image.fromarray( + # np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA') + np.clip(255.0 * item.cpu().numpy(), 0, 255).astype(np.uint8) + ) + img = np.array(item) + + boxes = groundingdino_predict( + grounding_dino_model, item, prompt, box_threshold, text_threshold + ) + + if boxes.shape[0] == 0: + break + sam_device = comfy.model_management.get_torch_device() + boxes = boxes.to(sam_device) + masks, scores, logits = sam_predict_torch( + sam_predictor, + img, + None, + boxes, + None, + multimask_output, + ) + outimage, mask_image = save_masks(masks, img) + print("why image.type", type(outimage)) + print("why mask_image.type", type(mask_image)) + images = (torch.from_numpy(outimage).float() / 255.0).unsqueeze(0) + masks = (torch.from_numpy(mask_image).float() / 255.0).unsqueeze(0) + print("1111") + print("why images.shape: ", images.shape) + res_images.append(images) + res_masks.append(masks) + if len(res_images) > 1: + output_image = torch.cat(res_images, dim=0) + output_mask = torch.cat(res_masks, dim=0) + else: + output_image = res_images[0] + output_mask = res_masks[0] + print("why outPUt: ", output_image.shape) + print("why outPUt: ", output_mask.shape) + return (output_image, output_mask) + # return (image, torch.cat(res_masks, dim=0)) + + +class BizyAirTrimapGenerate: + @classmethod + def INPUT_TYPES(cls): + method_list = [ + "VITMatte", + "VITMatte(local)", + "PyMatting", + "GuidedFilter", + ] + return { + "required": { + "image": ("IMAGE", {}), + "mask": ("MASK",), + "detail_method": (method_list,), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "black_point": ( + "FLOAT", + { + "default": 0.15, + "min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + "max_megapixels": ( + "FLOAT", + {"default": 2.0, "min": 1, 
"max": 999, "step": 0.1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + ) + RETURN_NAMES = ( + "image", + "mask", + ) + + def main( + self, + image, + mask, + detail_method, + detail_erode, + detail_dilate, + black_point, + white_point, + max_megapixels, + ): + if detail_method == "VITMatte(local)": + local_files_only = True + else: + local_files_only = False + + ret_images = [] + ret_masks = [] + device = comfy.model_management.get_torch_device() + print("image.shape:", image.shape) + print("image.shape[0]", image.shape[0]) + for i in range(image.shape[0]): + img = torch.unsqueeze(image[i], 0) + img = pil2tensor(tensor2pil(img).convert("RGB")) + _image = tensor2pil(img).convert("RGBA") + + detail_range = detail_erode + detail_dilate + if detail_method == "GuidedFilter": + _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) + _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) + elif detail_method == "PyMatting": + _mask = tensor2pil( + mask_edge_detail( + img, mask[i], detail_range // 8 + 1, black_point, white_point + ) + ) + else: + print("why trimap") + _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) + _mask = generate_VITMatte( + _image, + _trimap, + local_files_only=local_files_only, + device=device, + max_megapixels=max_megapixels, + ) + _mask = tensor2pil( + histogram_remap(pil2tensor(_mask), black_point, white_point) + ) + + # _mask = mask2image(_mask) + + _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) + + ret_images.append(pil2tensor(_image)) + ret_masks.append(image2mask(_mask)) + if len(ret_masks) == 0: + _, height, width, _ = image.size() + empty_mask = torch.zeros( + (1, height, width), dtype=torch.uint8, device="cpu" + ) + return (empty_mask, empty_mask) + + return ( + torch.cat(ret_images, dim=0), + torch.cat(ret_masks, dim=0), + ) + + NODE_CLASS_MAPPINGS = { "BizyAirSegmentAnythingText": BizyAirSegmentAnythingText, "BizyAirSegmentAnythingPointBox": BizyAirSegmentAnythingPointBox, + "BizyAirGroundingDinoModelLoader": BizyAirGroundingDinoModelLoader, + "BizyAirSAMModelLoader": BizyAirSAMModelLoader, + "BizyAirGroundingDinoSAMSegment": BizyAirGroundingDinoSAMSegment, + "BizyAirTrimapGenerate": BizyAirTrimapGenerate, } NODE_DISPLAY_NAME_MAPPINGS = { "BizyAirSegmentAnythingText": "☁️BizyAir Text Guided SAM", "BizyAirSegmentAnythingPointBox": "☁️BizyAir Point-Box Guided SAM", + "BizyAirGroundingDinoModelLoader": "☁️BizyAir Load GroundingDino Model", + "BizyAirSAMModelLoader": "☁️BizyAir Load SAM Model", + "BizyAirGroundingDinoSAMSegment": "☁️BizyAir GroundingDinoSAMSegment", + "BizyAirTrimapGenerate": "☁️BizyAir Trimap Generate", } From 13f4e0193f05cf5d073bbb9d9a0fd90af9452b2b Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Tue, 24 Dec 2024 17:04:13 +0800 Subject: [PATCH 02/13] refine the code --- __init__.py | 2 + sam.py | 457 ++++++++++++++++++++++++++++ sam_func.py | 381 ++++++++++++++++++++++++ segment_anything.py | 708 -------------------------------------------- 4 files changed, 840 insertions(+), 708 deletions(-) create mode 100644 sam.py create mode 100644 sam_func.py diff --git a/__init__.py b/__init__.py index eeab0941..acb20261 100644 --- a/__init__.py +++ b/__init__.py @@ -18,6 +18,7 @@ nodes, nodes_controlnet_aux, nodes_controlnet_union_sdxl, + sam, segment_anything, showcase, supernode, @@ -36,6 +37,7 @@ def update_mappings(module): update_mappings(nodes_controlnet_union_sdxl) 
update_mappings(mzkolors) update_mappings(segment_anything) +update_mappings(sam) try: import bizy_server diff --git a/sam.py b/sam.py new file mode 100644 index 00000000..80911f35 --- /dev/null +++ b/sam.py @@ -0,0 +1,457 @@ +import os +from pathlib import Path +from urllib.parse import urlparse + +import comfy.model_management +import folder_paths +import groundingdino.datasets.transforms as T +import numpy as np +import torch +from PIL import Image +from segment_anything_hq import SamPredictor, sam_model_registry + +from .sam_func import * + + +class BizyAirSAMModelLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": (list_sam_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("SAM_PREDICTOR",) + + def main(self, model_name): + sam_checkpoint_path = get_local_filepath( + sam_model_list[model_name]["model_url"], sam_model_dir_name + ) + model_file_name = os.path.basename(sam_checkpoint_path) + model_type = model_file_name.split(".")[0] + if "hq" not in model_type and "mobile" not in model_type: + model_type = "_".join(model_type.split("_")[1:-1]) + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path) + sam_device = comfy.model_management.get_torch_device() + sam.to(device=sam_device) + sam.eval() + sam.model_name = model_file_name + predictor = SamPredictor(sam) + + return (predictor,) + + +class BizyAirGroundingDinoModelLoader: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": (list_groundingdino_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("GROUNDING_DINO_MODEL",) + + def main(self, model_name): + dino_model = load_groundingdino_model(model_name) + return (dino_model,) + + +class BizyAirVITMatteModelLoader: + @classmethod + def INPUT_TYPES(cls): + method_list = [ + "VITMatte", + "VITMatte(local)", + ] + return { + "required": { + "detail_method": (method_list,), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ( + "VitMatte_MODEL", + "VitMatte_predictor", + ) + + def main(self, detail_method): + if detail_method == "VITMatte(local)": + local_files_only = True + else: + local_files_only = False + + model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) + from transformers import VitMatteForImageMatting, VitMatteImageProcessor + + device = comfy.model_management.get_torch_device() + + model = VitMatteForImageMatting.from_pretrained( + model_name, local_files_only=local_files_only + ) + processor = VitMatteImageProcessor.from_pretrained( + model_name, local_files_only=local_files_only + ) + model.to(device) + return (model, processor) + + +class BizyAirGroundingDinoSAMSegment: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}), + "sam_predictor": ("SAM_PREDICTOR", {}), + "image": ("IMAGE", {}), + "prompt": ("STRING", {}), + "box_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + "text_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("IMAGE", "MASK") + + def main( + self, + grounding_dino_model, + sam_predictor, + image, + prompt, + box_threshold, + text_threshold, + ): + res_images = [] + res_masks = [] + multimask_output = False + for item in image: + item = Image.fromarray( + # np.clip(255. 
* item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA') + np.clip(255.0 * item.cpu().numpy(), 0, 255).astype(np.uint8) + ) + img = np.array(item) + + boxes = groundingdino_predict( + grounding_dino_model, item, prompt, box_threshold, text_threshold + ) + + if boxes.shape[0] == 0: + break + sam_device = comfy.model_management.get_torch_device() + boxes = boxes.to(sam_device) + masks, scores, logits = sam_predict_torch( + sam_predictor, + img, + None, + boxes, + None, + multimask_output, + ) + outimage, mask_image = save_masks(masks, img) + images = (torch.from_numpy(outimage).float() / 255.0).unsqueeze(0) + masks = (torch.from_numpy(mask_image).float() / 255.0).unsqueeze(0) + res_images.append(images) + res_masks.append(masks) + if len(res_images) > 1: + output_image = torch.cat(res_images, dim=0) + output_mask = torch.cat(res_masks, dim=0) + else: + output_image = res_images[0] + output_mask = res_masks[0] + return (output_image, output_mask) + # return (image, torch.cat(res_masks, dim=0)) + + +class BizyAirTrimapGenerate1: + @classmethod + def INPUT_TYPES(cls): + method_list = [ + "VITMatte", + "VITMatte(local)", + "PyMatting", + "GuidedFilter", + ] + return { + "required": { + "image": ("IMAGE", {}), + "mask": ("MASK",), + "detail_method": (method_list,), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "black_point": ( + "FLOAT", + { + "default": 0.15, + "min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + "max_megapixels": ( + "FLOAT", + {"default": 2.0, "min": 1, "max": 999, "step": 0.1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + "MASK", + ) + RETURN_NAMES = ("image", "mask", "trimap") + + def main( + self, + image, + mask, + trimap, + detail_method, + detail_erode, + detail_dilate, + black_point, + white_point, + max_megapixels, + ): + if detail_method == "VITMatte(local)": + local_files_only = True + else: + local_files_only = False + + ret_images = [] + ret_masks = [] + device = comfy.model_management.get_torch_device() + print("image.shape:", image.shape) + print("image.shape[0]", image.shape[0]) + for i in range(image.shape[0]): + img = torch.unsqueeze(image[i], 0) + img = pil2tensor(tensor2pil(img).convert("RGB")) + _image = tensor2pil(img).convert("RGBA") + + detail_range = detail_erode + detail_dilate + if detail_method == "GuidedFilter": + _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) + _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) + elif detail_method == "PyMatting": + _mask = tensor2pil( + mask_edge_detail( + img, mask[i], detail_range // 8 + 1, black_point, white_point + ) + ) + else: + _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) + _mask = generate_VITMatte( + _image, + _trimap, + local_files_only=local_files_only, + device=device, + max_megapixels=max_megapixels, + ) + _mask = tensor2pil( + histogram_remap(pil2tensor(_mask), black_point, white_point) + ) + + # _mask = mask2image(_mask) + + _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) + + ret_images.append(pil2tensor(_image)) + ret_masks.append(image2mask(_mask)) + if len(ret_masks) == 0: + _, height, width, _ = image.size() + empty_mask = torch.zeros( + (1, height, 
width), dtype=torch.uint8, device="cpu" + ) + return (empty_mask, empty_mask) + + return ( + torch.cat(ret_images, dim=0), + torch.cat(ret_masks, dim=0), + pil2tensor(_trimap), + ) + + +class BizyAirTrimapGenerate: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("trimap",) + + def main( + self, + mask, + detail_erode, + detail_dilate, + ): + + ret_masks = [] + + for i in range(mask.shape[0]): + _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) + _trimap_tensor = pil2tensor(_trimap) + ret_masks.append(_trimap_tensor) + + return (torch.cat(ret_masks, dim=0),) + + +class BizyAirVITMattePredict: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE", {}), + "trimap": ("MASK",), + "vitmatte_model": ("VitMatte_MODEL", {}), + "vitmatte_predictor": ("VitMatte_predictor", {}), + "black_point": ( + "FLOAT", + { + "default": 0.15, + "min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + "max_megapixels": ( + "FLOAT", + {"default": 2.0, "min": 1, "max": 999, "step": 0.1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + ) + RETURN_NAMES = ( + "image", + "mask", + ) + + def main( + self, + image, + trimap, + vitmatte_model, + vitmatte_predictor, + black_point, + white_point, + max_megapixels, + ): + + ret_images = [] + ret_masks = [] + device = comfy.model_management.get_torch_device() + + for i in range(image.shape[0]): + img = torch.unsqueeze(image[i], 0) + img = pil2tensor(tensor2pil(img).convert("RGB")) + _image = tensor2pil(img).convert("RGBA") + _mask = generate_VITMatte( + vitmatte_model, + vitmatte_predictor, + _image, + tensor2pil(trimap[i]), + device=device, + max_megapixels=max_megapixels, + ) + _mask = tensor2pil( + histogram_remap(pil2tensor(_mask), black_point, white_point) + ) + + _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) + + ret_images.append(pil2tensor(_image)) + ret_masks.append(image2mask(_mask)) + if len(ret_masks) == 0: + _, height, width, _ = image.size() + empty_mask = torch.zeros( + (1, height, width), dtype=torch.uint8, device="cpu" + ) + return (empty_mask, empty_mask) + + return ( + torch.cat(ret_images, dim=0), + torch.cat(ret_masks, dim=0), + ) + + +NODE_CLASS_MAPPINGS = { + "BizyAirGroundingDinoModelLoader": BizyAirGroundingDinoModelLoader, + "BizyAirSAMModelLoader": BizyAirSAMModelLoader, + "BizyAirVITMatteModelLoader": BizyAirVITMatteModelLoader, + "BizyAirGroundingDinoSAMSegment": BizyAirGroundingDinoSAMSegment, + "BizyAirTrimapGenerate": BizyAirTrimapGenerate, + "BizyAirVITMattePredict": BizyAirVITMattePredict, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "BizyAirGroundingDinoModelLoader": "☁️BizyAir Load GroundingDino Model", + "BizyAirSAMModelLoader": "☁️BizyAir Load SAM Model", + "BizyAirVITMatteModelLoader": "☁️BizyAir Load VITMatte Model", + "BizyAirGroundingDinoSAMSegment": "☁️BizyAir GroundingDinoSAMSegment", + "BizyAirTrimapGenerate": "☁️BizyAir Trimap Generate", + "BizyAirVITMattePredict": "☁️BizyAir VITMatte Predict", +} diff --git 
a/sam_func.py b/sam_func.py new file mode 100644 index 00000000..883419d5 --- /dev/null +++ b/sam_func.py @@ -0,0 +1,381 @@ +import copy +import glob +import math +import os +from typing import List +from urllib.parse import urlparse + +import comfy.model_management +import cv2 +import folder_paths +import groundingdino.datasets.transforms as T +import numpy as np +import torch +from groundingdino.models import build_model as local_groundingdino_build_model +from groundingdino.util.inference import predict +from groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig +from groundingdino.util.utils import ( + clean_state_dict as local_groundingdino_clean_state_dict, +) +from PIL import Image +from torch.hub import download_url_to_file + +try: + from cv2.ximgproc import guidedFilter +except ImportError: + # print(e) + print( + f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" + f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." + f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" + ) + + +sam_model_dir_name = "sams" +sam_model_list = { + "sam_vit_h (2.56GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + }, + "sam_vit_l (1.25GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" + }, + "sam_vit_b (375MB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + }, + "sam_hq_vit_h (2.57GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" + }, + "sam_hq_vit_l (1.25GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" + }, + "sam_hq_vit_b (379MB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" + }, + "mobile_sam(39MB)": { + "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" + }, +} + +groundingdino_model_dir_name = "grounding-dino" +groundingdino_model_list = { + "GroundingDINO_SwinT_OGC (694MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", + }, + "GroundingDINO_SwinB (938MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", + }, +} + + +def get_bert_base_uncased_model_path(): + comfy_bert_model_base = os.path.join(folder_paths.models_dir, "bert-base-uncased") + if glob.glob( + os.path.join(comfy_bert_model_base, "**/model.safetensors"), recursive=True + ): + print("grounding-dino is using models/bert-base-uncased") + return comfy_bert_model_base + return "bert-base-uncased" + + +def save_masks(outmasks, image): + if len(outmasks) == 0: + return + if len(outmasks.shape) > 3: + outmasks = outmasks.permute(1, 0, 2, 3) + outmasks = ( + outmasks.view(outmasks.shape[1], outmasks.shape[2], outmasks.shape[3]) + .cpu() + .numpy() + ) + + image_height, image_width, _ = image.shape + + img = np.zeros((image_height, image_width, 3), dtype=np.uint8) + mask_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) + + for mask in outmasks: + non_zero_positions = np.nonzero(mask) + # 使用非零值位置从原始图像中提取实例部分 + instance_part = image[non_zero_positions] 
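        # copy the extracted instance pixels into the composite image at the same
        # positions, and mark those positions white (255) in the merged mask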
+ # 将提取的实例部分应用到合并图像中对应的位置 + img[non_zero_positions] = instance_part + + mask_image[non_zero_positions] = 255 # 标记为白色 + return img, mask_image + + +def list_sam_model(): + return list(sam_model_list.keys()) + + +def sam_predict_torch( + predictor, image_np, input_points, input_boxes, input_label, multimask_output +): + + image_np_rgb = image_np[..., :3] + predictor.set_image(image_np_rgb) + + transformed_boxes = predictor.transform.apply_boxes_torch( + input_boxes, image_np.shape[:2] + ) + + masks, scores, logits = predictor.predict_torch( + point_coords=input_points, + point_labels=input_label, + boxes=transformed_boxes, + multimask_output=multimask_output, + ) + + return masks, scores, logits + + +def get_local_filepath(url, dirname, local_file_name=None): + if not local_file_name: + parsed_url = urlparse(url) + local_file_name = os.path.basename(parsed_url.path) + + destination = folder_paths.get_full_path(dirname, local_file_name) + if destination: + return destination + + folder = os.path.join(folder_paths.models_dir, dirname) + if not os.path.exists(folder): + os.makedirs(folder) + + destination = os.path.join(folder, local_file_name) + if not os.path.exists(destination): + download_url_to_file(url, destination) + return destination + + +def load_groundingdino_model(model_name): + dino_model_args = local_groundingdino_SLConfig.fromfile( + get_local_filepath( + groundingdino_model_list[model_name]["config_url"], + groundingdino_model_dir_name, + ), + ) + + if dino_model_args.text_encoder_type == "bert-base-uncased": + dino_model_args.text_encoder_type = get_bert_base_uncased_model_path() + + dino = local_groundingdino_build_model(dino_model_args) + checkpoint = torch.load( + get_local_filepath( + groundingdino_model_list[model_name]["model_url"], + groundingdino_model_dir_name, + ), + ) + dino.load_state_dict( + local_groundingdino_clean_state_dict(checkpoint["model"]), strict=False + ) + device = comfy.model_management.get_torch_device() + dino.to(device=device) + dino.eval() + return dino + + +def list_groundingdino_model(): + return list(groundingdino_model_list.keys()) + + +def load_image(image_pil): + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + image_transformed, _ = transform(image_pil, None) + return image_transformed + + +def guided_filter_alpha( + image: torch.Tensor, mask: torch.Tensor, filter_radius: int +) -> torch.Tensor: + sigma = 0.15 + d = filter_radius + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + s = sigma / 10 + i_dup = copy.deepcopy(image.cpu().numpy()) + a_dup = copy.deepcopy(mask.cpu().numpy()) + for index, image in enumerate(i_dup): + alpha_work = a_dup[index] + i_dup[index] = guidedFilter(image, alpha_work, d, s) + return torch.from_numpy(i_dup) + + +def histogram_remap( + image: torch.Tensor, blackpoint: float, whitepoint: float +) -> torch.Tensor: + bp = min(blackpoint, whitepoint - 0.001) + scale = 1 / (whitepoint - bp) + i_dup = copy.deepcopy(image.cpu().numpy()) + i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0) + return torch.from_numpy(i_dup) + + +def mask_edge_detail( + image: torch.Tensor, + mask: torch.Tensor, + detail_range: int = 8, + black_point: float = 0.01, + white_point: float = 0.99, +) -> torch.Tensor: + from pymatting import estimate_alpha_cf, fix_trimap + + d = detail_range * 5 + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + i_dup = 
copy.deepcopy(image.cpu().numpy().astype(np.float64)) + a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64)) + for index, img in enumerate(i_dup): + trimap = a_dup[index][:, :, 0] # convert to single channel + if detail_range > 0: + trimap = cv2.GaussianBlur(trimap, (d, d), 0) + trimap = fix_trimap(trimap, black_point, white_point) + alpha = estimate_alpha_cf( + img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} + ) + a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb + return torch.from_numpy(a_dup.astype(np.float32)) + + +def generate_VITMatte_trimap( + mask: torch.Tensor, erode_kernel_size: int, dilate_kernel_size: int +) -> Image: + def g_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10): + erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8) + dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8) + eroded = cv2.erode(mask, erode_kernel, iterations=5) + dilated = cv2.dilate(mask, dilate_kernel, iterations=5) + trimap = np.zeros_like(mask) + trimap[dilated == 255] = 128 + trimap[eroded == 255] = 255 + return trimap + + mask = mask.squeeze(0).cpu().detach().numpy().astype(np.uint8) * 255 + trimap = g_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32) + trimap[trimap == 128] = 0.5 + trimap[trimap == 255] = 1 + trimap = torch.from_numpy(trimap).unsqueeze(0) + + return tensor2pil(trimap).convert("L") + + +def generate_VITMatte( + vit_matte_model, + vitmatte_predictor, + image: Image, + trimap: Image, + device: str = "cpu", + max_megapixels: float = 2.0, +) -> Image: + if image.mode != "RGB": + image = image.convert("RGB") + if trimap.mode != "L": + trimap = trimap.convert("L") + max_megapixels *= 1048576 + width, height = image.size + ratio = width / height + target_width = math.sqrt(ratio * max_megapixels) + target_height = target_width / ratio + target_width = int(target_width) + target_height = int(target_height) + if width * height > max_megapixels: + image = image.resize((target_width, target_height), Image.BILINEAR) + trimap = trimap.resize((target_width, target_height), Image.BILINEAR) + print( + f"vitmatte image size {width}x{height} too large, resize to {target_width}x{target_height} for processing." + ) + + print( + f"vitmatte processing, image size = {image.width}x{image.height}, device = {device}." 
+ ) + inputs = vitmatte_predictor(images=image, trimaps=trimap, return_tensors="pt") + with torch.no_grad(): + inputs = {k: v.to(device) for k, v in inputs.items()} + predictions = vit_matte_model(**inputs).alphas + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + mask = tensor2pil(predictions).convert("L") + mask = mask.crop( + (0, 0, image.width, image.height) + ) # remove padding that the prediction appends (works in 32px tiles) + if width * height > max_megapixels: + mask = mask.resize((width, height), Image.BILINEAR) + return mask + + +def pil2tensor(image: Image) -> torch.Tensor: + return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2pil(t_image: torch.Tensor) -> Image: + return Image.fromarray( + np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + ) + + +def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: + if len(tensor.shape) == 3: # Single image + return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) + else: # Batch of images + return [ + np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor + ] + + +def mask2image(mask: torch.Tensor) -> Image: + masks = tensor2np(mask) + for m in masks: + _mask = Image.fromarray(m).convert("L") + _image = Image.new("RGBA", _mask.size, color="white") + _image = Image.composite( + _image, Image.new("RGBA", _mask.size, color="black"), _mask + ) + return _image + + +def image2mask(image: Image) -> torch.Tensor: + _image = image.convert("RGBA") + alpha = _image.split()[0] + bg = Image.new("L", _image.size) + _image = Image.merge("RGBA", (bg, bg, bg, alpha)) + ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()]) + return ret_mask + + +def RGB2RGBA(image: Image, mask: Image) -> Image: + (R, G, B) = image.convert("RGB").split() + return Image.merge("RGBA", (R, G, B, mask.convert("L"))) + + +def groundingdino_predict(dino_model, image_pil, prompt, box_threshold, text_threshold): + image = load_image(image_pil) + + boxes, logits, phrases = predict( + model=dino_model, + image=image, + caption=prompt, + box_threshold=box_threshold, + text_threshold=text_threshold, + ) + + filt_mask = logits > box_threshold + boxes_filt = boxes.clone() + boxes_filt = boxes_filt[filt_mask] + H, W = image_pil.size[1], image_pil.size[0] + for i in range(boxes_filt.size(0)): + boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) + boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 + boxes_filt[i][2:] += boxes_filt[i][:2] + return boxes_filt diff --git a/segment_anything.py b/segment_anything.py index 310c4fbf..6e478ac7 100644 --- a/segment_anything.py +++ b/segment_anything.py @@ -1,29 +1,12 @@ -import copy -import glob import hashlib import json -import math import os from enum import Enum -from pathlib import Path -from typing import List, Union -from urllib.parse import urlparse -import comfy.model_management -import cv2 import folder_paths -import groundingdino.datasets.transforms as T import numpy as np import torch -from groundingdino.models import build_model as local_groundingdino_build_model -from groundingdino.util.inference import predict -from groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig -from groundingdino.util.utils import ( - clean_state_dict as local_groundingdino_clean_state_dict, -) from PIL import Image, ImageOps, ImageSequence -from segment_anything_hq import SamPredictor, sam_model_registry -from torch.hub import download_url_to_file from bizyair.common.env_var import 
BIZYAIR_SERVER_ADDRESS from bizyair.image_utils import decode_base64_to_np, encode_image_to_base64 @@ -32,53 +15,6 @@ from .route_sam import SAM_COORDINATE from .utils import get_api_key, send_post_request -try: - from cv2.ximgproc import guidedFilter -except ImportError: - # print(e) - print( - f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" - f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." - f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" - ) - -sam_model_dir_name = "sams" -sam_model_list = { - "sam_vit_h (2.56GB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" - }, - "sam_vit_l (1.25GB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" - }, - "sam_vit_b (375MB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" - }, - "sam_hq_vit_h (2.57GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" - }, - "sam_hq_vit_l (1.25GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" - }, - "sam_hq_vit_b (379MB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" - }, - "mobile_sam(39MB)": { - "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" - }, -} - -groundingdino_model_dir_name = "grounding-dino" -groundingdino_model_list = { - "GroundingDINO_SwinT_OGC (694MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", - }, - "GroundingDINO_SwinB (938MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", - }, -} - class INFER_MODE(Enum): auto = 0 @@ -92,389 +28,6 @@ class EDIT_MODE(Enum): point = 1 -def get_bert_base_uncased_model_path(): - comfy_bert_model_base = os.path.join(folder_paths.models_dir, "bert-base-uncased") - if glob.glob( - os.path.join(comfy_bert_model_base, "**/model.safetensors"), recursive=True - ): - print("grounding-dino is using models/bert-base-uncased") - return comfy_bert_model_base - return "bert-base-uncased" - - -def save_masks(outmasks, image): - if len(outmasks) == 0: - return - print("why masks: ", outmasks) - if len(outmasks.shape) > 3: - outmasks = outmasks.permute(1, 0, 2, 3) - outmasks = ( - outmasks.view(outmasks.shape[1], outmasks.shape[2], outmasks.shape[3]) - .cpu() - .numpy() - ) - - image_height, image_width, _ = image.shape - - img = np.zeros((image_height, image_width, 3), dtype=np.uint8) - mask_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - - for mask in outmasks: - non_zero_positions = np.nonzero(mask) - # 使用非零值位置从原始图像中提取实例部分 - instance_part = image[non_zero_positions] - # 将提取的实例部分应用到合并图像中对应的位置 - img[non_zero_positions] = instance_part - - mask_image[non_zero_positions] = 255 # 标记为白色 - print("why ahahahah") - return img, mask_image - - -def list_sam_model(): - return list(sam_model_list.keys()) - - -def load_sam_model(model_name): - sam_checkpoint_path = get_local_filepath( - sam_model_list[model_name]["model_url"], sam_model_dir_name - ) - model_file_name = os.path.basename(sam_checkpoint_path) - 
model_type = model_file_name.split(".")[0] - if "hq" not in model_type and "mobile" not in model_type: - model_type = "_".join(model_type.split("_")[1:-1]) - print("why model_type: ", model_type) - print("path:", sam_checkpoint_path) - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path) - sam_device = comfy.model_management.get_torch_device() - sam.to(device=sam_device) - sam.eval() - sam.model_name = model_file_name - predictor = SamPredictor(sam) - return (sam, predictor) - - -def sam_predict( - predictor, image, input_points, input_label, input_boxes, multimask_output -): - predictor.set_image(image) - masks, scores, logits = predictor.predict( - point_coords=input_points, - point_labels=input_label, - box=input_boxes, - multimask_output=multimask_output, - ) - return masks, scores, logits - - -def sam_predict_torch( - predictor, image_np, input_points, input_boxes, input_label, multimask_output -): - - image_np_rgb = image_np[..., :3] - predictor.set_image(image_np_rgb) - - transformed_boxes = predictor.transform.apply_boxes_torch( - input_boxes, image_np.shape[:2] - ) - - masks, scores, logits = predictor.predict_torch( - point_coords=input_points, - point_labels=input_label, - boxes=transformed_boxes, - multimask_output=multimask_output, - ) - - return masks, scores, logits - - -def get_local_filepath(url, dirname, local_file_name=None): - if not local_file_name: - parsed_url = urlparse(url) - local_file_name = os.path.basename(parsed_url.path) - - destination = folder_paths.get_full_path(dirname, local_file_name) - if destination: - return destination - - folder = os.path.join(folder_paths.models_dir, dirname) - if not os.path.exists(folder): - os.makedirs(folder) - - destination = os.path.join(folder, local_file_name) - if not os.path.exists(destination): - download_url_to_file(url, destination) - return destination - - -def load_groundingdino_model(model_name): - dino_model_args = local_groundingdino_SLConfig.fromfile( - get_local_filepath( - groundingdino_model_list[model_name]["config_url"], - groundingdino_model_dir_name, - ), - ) - - if dino_model_args.text_encoder_type == "bert-base-uncased": - dino_model_args.text_encoder_type = get_bert_base_uncased_model_path() - - dino = local_groundingdino_build_model(dino_model_args) - checkpoint = torch.load( - get_local_filepath( - groundingdino_model_list[model_name]["model_url"], - groundingdino_model_dir_name, - ), - ) - dino.load_state_dict( - local_groundingdino_clean_state_dict(checkpoint["model"]), strict=False - ) - device = comfy.model_management.get_torch_device() - dino.to(device=device) - dino.eval() - return dino - - -def list_groundingdino_model(): - return list(groundingdino_model_list.keys()) - - -def load_image(image_pil): - transform = T.Compose( - [ - T.RandomResize([800], max_size=1333), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) - - image_transformed, _ = transform(image_pil, None) - return image_transformed - - -def guided_filter_alpha( - image: torch.Tensor, mask: torch.Tensor, filter_radius: int -) -> torch.Tensor: - sigma = 0.15 - d = filter_radius + 1 - mask = pil2tensor(tensor2pil(mask).convert("RGB")) - if not bool(d % 2): - d += 1 - s = sigma / 10 - i_dup = copy.deepcopy(image.cpu().numpy()) - a_dup = copy.deepcopy(mask.cpu().numpy()) - for index, image in enumerate(i_dup): - alpha_work = a_dup[index] - i_dup[index] = guidedFilter(image, alpha_work, d, s) - return torch.from_numpy(i_dup) - - -def histogram_remap( - image: torch.Tensor, blackpoint: 
float, whitepoint: float -) -> torch.Tensor: - bp = min(blackpoint, whitepoint - 0.001) - scale = 1 / (whitepoint - bp) - i_dup = copy.deepcopy(image.cpu().numpy()) - i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0) - return torch.from_numpy(i_dup) - - -def mask_edge_detail( - image: torch.Tensor, - mask: torch.Tensor, - detail_range: int = 8, - black_point: float = 0.01, - white_point: float = 0.99, -) -> torch.Tensor: - from pymatting import estimate_alpha_cf, fix_trimap - - d = detail_range * 5 + 1 - mask = pil2tensor(tensor2pil(mask).convert("RGB")) - if not bool(d % 2): - d += 1 - i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64)) - a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64)) - for index, img in enumerate(i_dup): - trimap = a_dup[index][:, :, 0] # convert to single channel - if detail_range > 0: - trimap = cv2.GaussianBlur(trimap, (d, d), 0) - trimap = fix_trimap(trimap, black_point, white_point) - alpha = estimate_alpha_cf( - img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} - ) - a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb - return torch.from_numpy(a_dup.astype(np.float32)) - - -def generate_VITMatte_trimap( - mask: torch.Tensor, erode_kernel_size: int, dilate_kernel_size: int -) -> Image: - def g_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10): - erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8) - dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8) - eroded = cv2.erode(mask, erode_kernel, iterations=5) - dilated = cv2.dilate(mask, dilate_kernel, iterations=5) - trimap = np.zeros_like(mask) - trimap[dilated == 255] = 128 - trimap[eroded == 255] = 255 - return trimap - - mask = mask.squeeze(0).cpu().detach().numpy().astype(np.uint8) * 255 - trimap = g_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32) - trimap[trimap == 128] = 0.5 - trimap[trimap == 255] = 1 - trimap = torch.from_numpy(trimap).unsqueeze(0) - - return tensor2pil(trimap).convert("L") - - -def generate_VITMatte( - image: Image, - trimap: Image, - local_files_only: bool = False, - device: str = "cpu", - max_megapixels: float = 2.0, -) -> Image: - if image.mode != "RGB": - image = image.convert("RGB") - if trimap.mode != "L": - trimap = trimap.convert("L") - max_megapixels *= 1048576 - width, height = image.size - ratio = width / height - target_width = math.sqrt(ratio * max_megapixels) - target_height = target_width / ratio - target_width = int(target_width) - target_height = int(target_height) - if width * height > max_megapixels: - image = image.resize((target_width, target_height), Image.BILINEAR) - trimap = trimap.resize((target_width, target_height), Image.BILINEAR) - print( - f"vitmatte image size {width}x{height} too large, resize to {target_width}x{target_height} for processing." - ) - model_name = "hustvl/vitmatte-small-composition-1k" - if device == "cpu": - device = torch.device("cpu") - else: - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - print( - "vitmatte device is set to cuda, but not available, using cpu instead." - ) - device = torch.device("cpu") - vit_matte_model = load_VITMatte_model( - model_name=model_name, local_files_only=local_files_only - ) - vit_matte_model.model.to(device) - print( - f"vitmatte processing, image size = {image.width}x{image.height}, device = {device}." 
- ) - inputs = vit_matte_model.processor( - images=image, trimaps=trimap, return_tensors="pt" - ) - with torch.no_grad(): - inputs = {k: v.to(device) for k, v in inputs.items()} - predictions = vit_matte_model.model(**inputs).alphas - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - mask = tensor2pil(predictions).convert("L") - mask = mask.crop( - (0, 0, image.width, image.height) - ) # remove padding that the prediction appends (works in 32px tiles) - if width * height > max_megapixels: - mask = mask.resize((width, height), Image.BILINEAR) - return mask - - -class VITMatteModel: - def __init__(self, model, processor): - self.model = model - self.processor = processor - - -def load_VITMatte_model(model_name: str, local_files_only: bool = False) -> object: - # if local_files_only: - # model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) - model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) - from transformers import VitMatteForImageMatting, VitMatteImageProcessor - - model = VitMatteForImageMatting.from_pretrained( - model_name, local_files_only=local_files_only - ) - processor = VitMatteImageProcessor.from_pretrained( - model_name, local_files_only=local_files_only - ) - vitmatte = VITMatteModel(model, processor) - return vitmatte - - -def pil2tensor(image: Image) -> torch.Tensor: - return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) - - -def tensor2pil(t_image: torch.Tensor) -> Image: - return Image.fromarray( - np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) - ) - - -def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: - if len(tensor.shape) == 3: # Single image - return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) - else: # Batch of images - return [ - np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor - ] - - -def mask2image(mask: torch.Tensor) -> Image: - masks = tensor2np(mask) - for m in masks: - _mask = Image.fromarray(m).convert("L") - _image = Image.new("RGBA", _mask.size, color="white") - _image = Image.composite( - _image, Image.new("RGBA", _mask.size, color="black"), _mask - ) - return _image - - -def image2mask(image: Image) -> torch.Tensor: - _image = image.convert("RGBA") - alpha = _image.split()[0] - bg = Image.new("L", _image.size) - _image = Image.merge("RGBA", (bg, bg, bg, alpha)) - ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()]) - return ret_mask - - -def RGB2RGBA(image: Image, mask: Image) -> Image: - (R, G, B) = image.convert("RGB").split() - return Image.merge("RGBA", (R, G, B, mask.convert("L"))) - - -def groundingdino_predict(dino_model, image_pil, prompt, box_threshold, text_threshold): - image = load_image(image_pil) - - boxes, logits, phrases = predict( - model=dino_model, - image=image, - caption=prompt, - box_threshold=box_threshold, - text_threshold=text_threshold, - ) - - filt_mask = logits > box_threshold - boxes_filt = boxes.clone() - boxes_filt = boxes_filt[filt_mask] - H, W = image_pil.size[1], image_pil.size[0] - for i in range(boxes_filt.size(0)): - boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) - boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 - boxes_filt[i][2:] += boxes_filt[i][:2] - return boxes_filt - - class BizyAirSegmentAnythingText: API_URL = f"{BIZYAIR_SERVER_ADDRESS}/supernode/sam" @@ -714,272 +267,11 @@ def VALIDATE_INPUTS(s, image, is_point): return True -class BizyAirSAMModelLoader: - @classmethod - def INPUT_TYPES(cls): - return { - "required": 
{ - "model_name": (list_sam_model(),), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ( - "SAM_MODEL", - "SAM_PREDICTOR", - ) - - def main(self, model_name): - sam_model, sam_predictor = load_sam_model(model_name) - return ( - sam_model, - sam_predictor, - ) - - -class BizyAirGroundingDinoModelLoader: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model_name": (list_groundingdino_model(),), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("GROUNDING_DINO_MODEL",) - - def main(self, model_name): - dino_model = load_groundingdino_model(model_name) - return (dino_model,) - - -class BizyAirGroundingDinoSAMSegment: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "sam_model": ("SAM_MODEL", {}), - "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}), - "sam_predictor": ("SAM_PREDICTOR", {}), - "image": ("IMAGE", {}), - "prompt": ("STRING", {}), - "box_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - "text_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("IMAGE", "MASK") - - def main( - self, - grounding_dino_model, - sam_model, - sam_predictor, - image, - prompt, - box_threshold, - text_threshold, - ): - res_images = [] - res_masks = [] - multimask_output = False - for item in image: - item = Image.fromarray( - # np.clip(255. * item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA') - np.clip(255.0 * item.cpu().numpy(), 0, 255).astype(np.uint8) - ) - img = np.array(item) - - boxes = groundingdino_predict( - grounding_dino_model, item, prompt, box_threshold, text_threshold - ) - - if boxes.shape[0] == 0: - break - sam_device = comfy.model_management.get_torch_device() - boxes = boxes.to(sam_device) - masks, scores, logits = sam_predict_torch( - sam_predictor, - img, - None, - boxes, - None, - multimask_output, - ) - outimage, mask_image = save_masks(masks, img) - print("why image.type", type(outimage)) - print("why mask_image.type", type(mask_image)) - images = (torch.from_numpy(outimage).float() / 255.0).unsqueeze(0) - masks = (torch.from_numpy(mask_image).float() / 255.0).unsqueeze(0) - print("1111") - print("why images.shape: ", images.shape) - res_images.append(images) - res_masks.append(masks) - if len(res_images) > 1: - output_image = torch.cat(res_images, dim=0) - output_mask = torch.cat(res_masks, dim=0) - else: - output_image = res_images[0] - output_mask = res_masks[0] - print("why outPUt: ", output_image.shape) - print("why outPUt: ", output_mask.shape) - return (output_image, output_mask) - # return (image, torch.cat(res_masks, dim=0)) - - -class BizyAirTrimapGenerate: - @classmethod - def INPUT_TYPES(cls): - method_list = [ - "VITMatte", - "VITMatte(local)", - "PyMatting", - "GuidedFilter", - ] - return { - "required": { - "image": ("IMAGE", {}), - "mask": ("MASK",), - "detail_method": (method_list,), - "detail_erode": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "detail_dilate": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "black_point": ( - "FLOAT", - { - "default": 0.15, - "min": 0.01, - "max": 0.98, - "step": 0.01, - "display": "slider", - }, - ), - "white_point": ( - "FLOAT", - { - "default": 0.99, - "min": 0.02, - "max": 0.99, - "step": 0.01, - "display": "slider", - }, - ), - "max_megapixels": ( - "FLOAT", - {"default": 2.0, "min": 1, 
"max": 999, "step": 0.1}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ( - "IMAGE", - "MASK", - ) - RETURN_NAMES = ( - "image", - "mask", - ) - - def main( - self, - image, - mask, - detail_method, - detail_erode, - detail_dilate, - black_point, - white_point, - max_megapixels, - ): - if detail_method == "VITMatte(local)": - local_files_only = True - else: - local_files_only = False - - ret_images = [] - ret_masks = [] - device = comfy.model_management.get_torch_device() - print("image.shape:", image.shape) - print("image.shape[0]", image.shape[0]) - for i in range(image.shape[0]): - img = torch.unsqueeze(image[i], 0) - img = pil2tensor(tensor2pil(img).convert("RGB")) - _image = tensor2pil(img).convert("RGBA") - - detail_range = detail_erode + detail_dilate - if detail_method == "GuidedFilter": - _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) - _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) - elif detail_method == "PyMatting": - _mask = tensor2pil( - mask_edge_detail( - img, mask[i], detail_range // 8 + 1, black_point, white_point - ) - ) - else: - print("why trimap") - _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) - _mask = generate_VITMatte( - _image, - _trimap, - local_files_only=local_files_only, - device=device, - max_megapixels=max_megapixels, - ) - _mask = tensor2pil( - histogram_remap(pil2tensor(_mask), black_point, white_point) - ) - - # _mask = mask2image(_mask) - - _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) - - ret_images.append(pil2tensor(_image)) - ret_masks.append(image2mask(_mask)) - if len(ret_masks) == 0: - _, height, width, _ = image.size() - empty_mask = torch.zeros( - (1, height, width), dtype=torch.uint8, device="cpu" - ) - return (empty_mask, empty_mask) - - return ( - torch.cat(ret_images, dim=0), - torch.cat(ret_masks, dim=0), - ) - - NODE_CLASS_MAPPINGS = { "BizyAirSegmentAnythingText": BizyAirSegmentAnythingText, "BizyAirSegmentAnythingPointBox": BizyAirSegmentAnythingPointBox, - "BizyAirGroundingDinoModelLoader": BizyAirGroundingDinoModelLoader, - "BizyAirSAMModelLoader": BizyAirSAMModelLoader, - "BizyAirGroundingDinoSAMSegment": BizyAirGroundingDinoSAMSegment, - "BizyAirTrimapGenerate": BizyAirTrimapGenerate, } NODE_DISPLAY_NAME_MAPPINGS = { "BizyAirSegmentAnythingText": "☁️BizyAir Text Guided SAM", "BizyAirSegmentAnythingPointBox": "☁️BizyAir Point-Box Guided SAM", - "BizyAirGroundingDinoModelLoader": "☁️BizyAir Load GroundingDino Model", - "BizyAirSAMModelLoader": "☁️BizyAir Load SAM Model", - "BizyAirGroundingDinoSAMSegment": "☁️BizyAir GroundingDinoSAMSegment", - "BizyAirTrimapGenerate": "☁️BizyAir Trimap Generate", } From 82755e3a88525edec295ba1f6bcc256626e11704 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Tue, 31 Dec 2024 15:10:40 +0800 Subject: [PATCH 03/13] refine the code --- __init__.py | 2 - bizyair_extras/__init__.py | 1 + bizyair_extras/nodes_segment_anything.py | 206 +++++++ examples/bizyair_segment_anything_ultra.json | 573 +++++++++++++++++++ sam.py | 457 --------------- 5 files changed, 780 insertions(+), 459 deletions(-) create mode 100644 bizyair_extras/nodes_segment_anything.py create mode 100644 examples/bizyair_segment_anything_ultra.json delete mode 100644 sam.py diff --git a/__init__.py b/__init__.py index acb20261..eeab0941 100644 --- a/__init__.py +++ b/__init__.py @@ -18,7 +18,6 @@ nodes, nodes_controlnet_aux, nodes_controlnet_union_sdxl, - sam, segment_anything, 
showcase, supernode, @@ -37,7 +36,6 @@ def update_mappings(module): update_mappings(nodes_controlnet_union_sdxl) update_mappings(mzkolors) update_mappings(segment_anything) -update_mappings(sam) try: import bizy_server diff --git a/bizyair_extras/__init__.py b/bizyair_extras/__init__.py index a3a9edb0..418c4055 100644 --- a/bizyair_extras/__init__.py +++ b/bizyair_extras/__init__.py @@ -13,6 +13,7 @@ from .nodes_kolors_mz import * from .nodes_model_advanced import * from .nodes_sd3 import * +from .nodes_segment_anything import * from .nodes_testing_utils import * from .nodes_ultimatesdupscale import * from .nodes_upscale_model import * diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py new file mode 100644 index 00000000..2e565413 --- /dev/null +++ b/bizyair_extras/nodes_segment_anything.py @@ -0,0 +1,206 @@ +import os +from pathlib import Path +from urllib.parse import urlparse + +import comfy.model_management +import folder_paths +import groundingdino.datasets.transforms as T +import numpy as np +import torch + +from bizyair import BizyAirBaseNode + +sam_model_list = { + "sam_vit_h (2.56GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + }, + "sam_vit_l (1.25GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" + }, + "sam_vit_b (375MB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + }, + "sam_hq_vit_h (2.57GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" + }, + "sam_hq_vit_l (1.25GB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" + }, + "sam_hq_vit_b (379MB)": { + "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" + }, + "mobile_sam(39MB)": { + "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" + }, +} + +groundingdino_model_dir_name = "grounding-dino" +groundingdino_model_list = { + "GroundingDINO_SwinT_OGC (694MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", + }, + "GroundingDINO_SwinB (938MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", + }, +} + + +def list_sam_model(): + return list(sam_model_list.keys()) + + +def list_groundingdino_model(): + return list(groundingdino_model_list.keys()) + + +class BizyAir_SAMModelLoader(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": (list_sam_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ("SAM_PREDICTOR",) + NODE_DISPLAY_NAME = "☁️BizyAir Load SAM Model" + + +class BizyAir_GroundingDinoModelLoader(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model_name": (list_groundingdino_model(),), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ("GROUNDING_DINO_MODEL",) + NODE_DISPLAY_NAME = "☁️BizyAir Load GroundingDino Model" + + +class BizyAir_VITMatteModelLoader(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + method_list = [ + "VITMatte", + "VITMatte(local)", + ] + return { + "required": 
{ + "detail_method": (method_list,), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ( + "VitMatte_MODEL", + "VitMatte_predictor", + ) + NODE_DISPLAY_NAME = "☁️BizyAir Load VITMatte Model" + + +class BizyAir_GroundingDinoSAMSegment(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}), + "sam_predictor": ("SAM_PREDICTOR", {}), + "image": ("IMAGE", {}), + "prompt": ("STRING", {}), + "box_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + "text_threshold": ( + "FLOAT", + {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ("IMAGE", "MASK") + NODE_DISPLAY_NAME = "☁️BizyAir GroundingDinoSAMSegment" + + +class BizyAir_TrimapGenerate(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("trimap",) + NODE_DISPLAY_NAME = "☁️BizyAir Trimap Generate" + + +class BizyAir_VITMattePredict(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE", {}), + "trimap": ("MASK",), + "vitmatte_model": ("VitMatte_MODEL", {}), + "vitmatte_predictor": ("VitMatte_predictor", {}), + "black_point": ( + "FLOAT", + { + "default": 0.15, + "min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + "max_megapixels": ( + "FLOAT", + {"default": 2.0, "min": 1, "max": 999, "step": 0.1}, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + ) + RETURN_NAMES = ( + "image", + "mask", + ) + NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict" diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json new file mode 100644 index 00000000..947a4538 --- /dev/null +++ b/examples/bizyair_segment_anything_ultra.json @@ -0,0 +1,573 @@ +{ + "last_node_id": 23, + "last_link_id": 49, + "nodes": [ + { + "id": 22, + "type": "BizyAir_TrimapGenerate", + "pos": [ + 691.5542602539062, + 25.333850860595703 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 42 + } + ], + "outputs": [ + { + "name": "trimap", + "type": "MASK", + "links": [ + 46, + 49 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_TrimapGenerate" + }, + "widgets_values": [ + 6, + 6 + ] + }, + { + "id": 19, + "type": "BizyAir_VITMatteModelLoader", + "pos": [ + 668.5579833984375, + -124.51361846923828 + ], + "size": [ + 365.4000244140625, + 78 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VitMatte_MODEL", + "type": "VitMatte_MODEL", + "links": [ + 44 + ], + "slot_index": 0 + }, + { + "name": "VitMatte_predictor", + "type": "VitMatte_predictor", + "links": [ + 45 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VITMatteModelLoader" + }, + "widgets_values": [ + "VITMatte" + ] + }, + { + "id": 
23, + "type": "BizyAir_VITMattePredict", + "pos": [ + 1073.8350830078125, + -53.61327362060547 + ], + "size": [ + 327.5999755859375, + 166 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 43 + }, + { + "name": "trimap", + "type": "MASK", + "link": 46 + }, + { + "name": "vitmatte_model", + "type": "VitMatte_MODEL", + "link": 44 + }, + { + "name": "vitmatte_predictor", + "type": "VitMatte_predictor", + "link": 45 + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 48 + ], + "slot_index": 0 + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 47 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VITMattePredict" + }, + "widgets_values": [ + 0.15, + 0.99, + 2 + ] + }, + { + "id": 9, + "type": "PreviewImage", + "pos": [ + 1457.2244873046875, + -111.08157348632812 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 48 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 20, + "type": "BizyAir_GroundingDinoModelLoader", + "pos": [ + -119.09739685058594, + -87.74108123779297 + ], + "size": [ + 415.8000183105469, + 58 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "GROUNDING_DINO_MODEL", + "type": "GROUNDING_DINO_MODEL", + "links": [ + 35 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoModelLoader" + }, + "widgets_values": [ + "GroundingDINO_SwinT_OGC (694MB)" + ] + }, + { + "id": 18, + "type": "BizyAir_SAMModelLoader", + "pos": [ + -24.793033599853516, + 40.01811218261719 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "SAM_PREDICTOR", + "type": "SAM_PREDICTOR", + "links": [ + 32, + 36 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_SAMModelLoader" + }, + "widgets_values": [ + "sam_vit_h (2.56GB)" + ] + }, + { + "id": 5, + "type": "MaskPreview+", + "pos": [ + 338.49713134765625, + 236.64303588867188 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 41 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "MaskPreview+" + }, + "widgets_values": [] + }, + { + "id": 4, + "type": "PreviewImage", + "pos": [ + 585.2964477539062, + 235.61814880371094 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 40 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 8, + "type": "MaskPreview+", + "pos": [ + 852.276611328125, + 234.75250244140625 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 49 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "MaskPreview+" + }, + "widgets_values": [] + }, + { + "id": 6, + "type": "LoadImage", + "pos": [ + -29.21724510192871, + 167.51760864257812 + ], + "size": [ + 315, + 314 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 38 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": 
{ + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "0809.png", + "image" + ] + }, + { + "id": 21, + "type": "BizyAir_GroundingDinoSAMSegment", + "pos": [ + 334.36334228515625, + -26.426876068115234 + ], + "size": [ + 286.7333068847656, + 146 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "grounding_dino_model", + "type": "GROUNDING_DINO_MODEL", + "link": 35 + }, + { + "name": "sam_predictor", + "type": "SAM_PREDICTOR", + "link": 36 + }, + { + "name": "image", + "type": "IMAGE", + "link": 38 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 40, + 43 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": [ + 41, + 42 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" + }, + "widgets_values": [ + "lion", + 0.3, + 0.3 + ] + }, + { + "id": 15, + "type": "MaskPreview+", + "pos": [ + 1448.48388671875, + 234.31948852539062 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 47 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "MaskPreview+" + }, + "widgets_values": [] + } + ], + "links": [ + [ + 32, + 18, + 0, + 17, + 1, + "SAM_PREDICTOR" + ], + [ + 35, + 20, + 0, + 21, + 0, + "GROUNDING_DINO_MODEL" + ], + [ + 36, + 18, + 0, + 21, + 1, + "SAM_PREDICTOR" + ], + [ + 38, + 6, + 0, + 21, + 2, + "IMAGE" + ], + [ + 40, + 21, + 0, + 4, + 0, + "IMAGE" + ], + [ + 41, + 21, + 1, + 5, + 0, + "MASK" + ], + [ + 42, + 21, + 1, + 22, + 0, + "MASK" + ], + [ + 43, + 21, + 0, + 23, + 0, + "IMAGE" + ], + [ + 44, + 19, + 0, + 23, + 2, + "VitMatte_MODEL" + ], + [ + 45, + 19, + 1, + 23, + 3, + "VitMatte_predictor" + ], + [ + 46, + 22, + 0, + 23, + 1, + "MASK" + ], + [ + 47, + 23, + 1, + 15, + 0, + "MASK" + ], + [ + 48, + 23, + 0, + 9, + 0, + "IMAGE" + ], + [ + 49, + 22, + 0, + 8, + 0, + "MASK" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 0.42409761837248516, + "offset": [ + 259.23479036354416, + 425.1317591221424 + ] + } + }, + "version": 0.4 +} diff --git a/sam.py b/sam.py deleted file mode 100644 index 80911f35..00000000 --- a/sam.py +++ /dev/null @@ -1,457 +0,0 @@ -import os -from pathlib import Path -from urllib.parse import urlparse - -import comfy.model_management -import folder_paths -import groundingdino.datasets.transforms as T -import numpy as np -import torch -from PIL import Image -from segment_anything_hq import SamPredictor, sam_model_registry - -from .sam_func import * - - -class BizyAirSAMModelLoader: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model_name": (list_sam_model(),), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("SAM_PREDICTOR",) - - def main(self, model_name): - sam_checkpoint_path = get_local_filepath( - sam_model_list[model_name]["model_url"], sam_model_dir_name - ) - model_file_name = os.path.basename(sam_checkpoint_path) - model_type = model_file_name.split(".")[0] - if "hq" not in model_type and "mobile" not in model_type: - model_type = "_".join(model_type.split("_")[1:-1]) - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path) - sam_device = comfy.model_management.get_torch_device() - sam.to(device=sam_device) - sam.eval() - sam.model_name = model_file_name - predictor = SamPredictor(sam) - - return (predictor,) - - -class BizyAirGroundingDinoModelLoader: - @classmethod - def INPUT_TYPES(cls): - return { - 
"required": { - "model_name": (list_groundingdino_model(),), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("GROUNDING_DINO_MODEL",) - - def main(self, model_name): - dino_model = load_groundingdino_model(model_name) - return (dino_model,) - - -class BizyAirVITMatteModelLoader: - @classmethod - def INPUT_TYPES(cls): - method_list = [ - "VITMatte", - "VITMatte(local)", - ] - return { - "required": { - "detail_method": (method_list,), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ( - "VitMatte_MODEL", - "VitMatte_predictor", - ) - - def main(self, detail_method): - if detail_method == "VITMatte(local)": - local_files_only = True - else: - local_files_only = False - - model_name = Path(os.path.join(folder_paths.models_dir, "vitmatte")) - from transformers import VitMatteForImageMatting, VitMatteImageProcessor - - device = comfy.model_management.get_torch_device() - - model = VitMatteForImageMatting.from_pretrained( - model_name, local_files_only=local_files_only - ) - processor = VitMatteImageProcessor.from_pretrained( - model_name, local_files_only=local_files_only - ) - model.to(device) - return (model, processor) - - -class BizyAirGroundingDinoSAMSegment: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}), - "sam_predictor": ("SAM_PREDICTOR", {}), - "image": ("IMAGE", {}), - "prompt": ("STRING", {}), - "box_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - "text_threshold": ( - "FLOAT", - {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("IMAGE", "MASK") - - def main( - self, - grounding_dino_model, - sam_predictor, - image, - prompt, - box_threshold, - text_threshold, - ): - res_images = [] - res_masks = [] - multimask_output = False - for item in image: - item = Image.fromarray( - # np.clip(255. 
* item.cpu().numpy(), 0, 255).astype(np.uint8)).convert('RGBA') - np.clip(255.0 * item.cpu().numpy(), 0, 255).astype(np.uint8) - ) - img = np.array(item) - - boxes = groundingdino_predict( - grounding_dino_model, item, prompt, box_threshold, text_threshold - ) - - if boxes.shape[0] == 0: - break - sam_device = comfy.model_management.get_torch_device() - boxes = boxes.to(sam_device) - masks, scores, logits = sam_predict_torch( - sam_predictor, - img, - None, - boxes, - None, - multimask_output, - ) - outimage, mask_image = save_masks(masks, img) - images = (torch.from_numpy(outimage).float() / 255.0).unsqueeze(0) - masks = (torch.from_numpy(mask_image).float() / 255.0).unsqueeze(0) - res_images.append(images) - res_masks.append(masks) - if len(res_images) > 1: - output_image = torch.cat(res_images, dim=0) - output_mask = torch.cat(res_masks, dim=0) - else: - output_image = res_images[0] - output_mask = res_masks[0] - return (output_image, output_mask) - # return (image, torch.cat(res_masks, dim=0)) - - -class BizyAirTrimapGenerate1: - @classmethod - def INPUT_TYPES(cls): - method_list = [ - "VITMatte", - "VITMatte(local)", - "PyMatting", - "GuidedFilter", - ] - return { - "required": { - "image": ("IMAGE", {}), - "mask": ("MASK",), - "detail_method": (method_list,), - "detail_erode": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "detail_dilate": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "black_point": ( - "FLOAT", - { - "default": 0.15, - "min": 0.01, - "max": 0.98, - "step": 0.01, - "display": "slider", - }, - ), - "white_point": ( - "FLOAT", - { - "default": 0.99, - "min": 0.02, - "max": 0.99, - "step": 0.01, - "display": "slider", - }, - ), - "max_megapixels": ( - "FLOAT", - {"default": 2.0, "min": 1, "max": 999, "step": 0.1}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ( - "IMAGE", - "MASK", - "MASK", - ) - RETURN_NAMES = ("image", "mask", "trimap") - - def main( - self, - image, - mask, - trimap, - detail_method, - detail_erode, - detail_dilate, - black_point, - white_point, - max_megapixels, - ): - if detail_method == "VITMatte(local)": - local_files_only = True - else: - local_files_only = False - - ret_images = [] - ret_masks = [] - device = comfy.model_management.get_torch_device() - print("image.shape:", image.shape) - print("image.shape[0]", image.shape[0]) - for i in range(image.shape[0]): - img = torch.unsqueeze(image[i], 0) - img = pil2tensor(tensor2pil(img).convert("RGB")) - _image = tensor2pil(img).convert("RGBA") - - detail_range = detail_erode + detail_dilate - if detail_method == "GuidedFilter": - _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) - _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) - elif detail_method == "PyMatting": - _mask = tensor2pil( - mask_edge_detail( - img, mask[i], detail_range // 8 + 1, black_point, white_point - ) - ) - else: - _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) - _mask = generate_VITMatte( - _image, - _trimap, - local_files_only=local_files_only, - device=device, - max_megapixels=max_megapixels, - ) - _mask = tensor2pil( - histogram_remap(pil2tensor(_mask), black_point, white_point) - ) - - # _mask = mask2image(_mask) - - _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) - - ret_images.append(pil2tensor(_image)) - ret_masks.append(image2mask(_mask)) - if len(ret_masks) == 0: - _, height, width, _ = image.size() - empty_mask = torch.zeros( - (1, height, 
width), dtype=torch.uint8, device="cpu" - ) - return (empty_mask, empty_mask) - - return ( - torch.cat(ret_images, dim=0), - torch.cat(ret_masks, dim=0), - pil2tensor(_trimap), - ) - - -class BizyAirTrimapGenerate: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "detail_erode": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "detail_dilate": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ("MASK",) - RETURN_NAMES = ("trimap",) - - def main( - self, - mask, - detail_erode, - detail_dilate, - ): - - ret_masks = [] - - for i in range(mask.shape[0]): - _trimap = generate_VITMatte_trimap(mask[i], detail_erode, detail_dilate) - _trimap_tensor = pil2tensor(_trimap) - ret_masks.append(_trimap_tensor) - - return (torch.cat(ret_masks, dim=0),) - - -class BizyAirVITMattePredict: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE", {}), - "trimap": ("MASK",), - "vitmatte_model": ("VitMatte_MODEL", {}), - "vitmatte_predictor": ("VitMatte_predictor", {}), - "black_point": ( - "FLOAT", - { - "default": 0.15, - "min": 0.01, - "max": 0.98, - "step": 0.01, - "display": "slider", - }, - ), - "white_point": ( - "FLOAT", - { - "default": 0.99, - "min": 0.02, - "max": 0.99, - "step": 0.01, - "display": "slider", - }, - ), - "max_megapixels": ( - "FLOAT", - {"default": 2.0, "min": 1, "max": 999, "step": 0.1}, - ), - } - } - - CATEGORY = "☁️BizyAir/segment-anything" - FUNCTION = "main" - RETURN_TYPES = ( - "IMAGE", - "MASK", - ) - RETURN_NAMES = ( - "image", - "mask", - ) - - def main( - self, - image, - trimap, - vitmatte_model, - vitmatte_predictor, - black_point, - white_point, - max_megapixels, - ): - - ret_images = [] - ret_masks = [] - device = comfy.model_management.get_torch_device() - - for i in range(image.shape[0]): - img = torch.unsqueeze(image[i], 0) - img = pil2tensor(tensor2pil(img).convert("RGB")) - _image = tensor2pil(img).convert("RGBA") - _mask = generate_VITMatte( - vitmatte_model, - vitmatte_predictor, - _image, - tensor2pil(trimap[i]), - device=device, - max_megapixels=max_megapixels, - ) - _mask = tensor2pil( - histogram_remap(pil2tensor(_mask), black_point, white_point) - ) - - _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) - - ret_images.append(pil2tensor(_image)) - ret_masks.append(image2mask(_mask)) - if len(ret_masks) == 0: - _, height, width, _ = image.size() - empty_mask = torch.zeros( - (1, height, width), dtype=torch.uint8, device="cpu" - ) - return (empty_mask, empty_mask) - - return ( - torch.cat(ret_images, dim=0), - torch.cat(ret_masks, dim=0), - ) - - -NODE_CLASS_MAPPINGS = { - "BizyAirGroundingDinoModelLoader": BizyAirGroundingDinoModelLoader, - "BizyAirSAMModelLoader": BizyAirSAMModelLoader, - "BizyAirVITMatteModelLoader": BizyAirVITMatteModelLoader, - "BizyAirGroundingDinoSAMSegment": BizyAirGroundingDinoSAMSegment, - "BizyAirTrimapGenerate": BizyAirTrimapGenerate, - "BizyAirVITMattePredict": BizyAirVITMattePredict, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "BizyAirGroundingDinoModelLoader": "☁️BizyAir Load GroundingDino Model", - "BizyAirSAMModelLoader": "☁️BizyAir Load SAM Model", - "BizyAirVITMatteModelLoader": "☁️BizyAir Load VITMatte Model", - "BizyAirGroundingDinoSAMSegment": "☁️BizyAir GroundingDinoSAMSegment", - "BizyAirTrimapGenerate": "☁️BizyAir Trimap Generate", - "BizyAirVITMattePredict": "☁️BizyAir VITMatte Predict", -} From 
481cce3c1ee2f26c24d8eed86fbac90df57a80d6 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Tue, 31 Dec 2024 15:42:11 +0800 Subject: [PATCH 04/13] refine the code --- sam_func.py | 381 ---------------------------------------------------- 1 file changed, 381 deletions(-) delete mode 100644 sam_func.py diff --git a/sam_func.py b/sam_func.py deleted file mode 100644 index 883419d5..00000000 --- a/sam_func.py +++ /dev/null @@ -1,381 +0,0 @@ -import copy -import glob -import math -import os -from typing import List -from urllib.parse import urlparse - -import comfy.model_management -import cv2 -import folder_paths -import groundingdino.datasets.transforms as T -import numpy as np -import torch -from groundingdino.models import build_model as local_groundingdino_build_model -from groundingdino.util.inference import predict -from groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig -from groundingdino.util.utils import ( - clean_state_dict as local_groundingdino_clean_state_dict, -) -from PIL import Image -from torch.hub import download_url_to_file - -try: - from cv2.ximgproc import guidedFilter -except ImportError: - # print(e) - print( - f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" - f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." - f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" - ) - - -sam_model_dir_name = "sams" -sam_model_list = { - "sam_vit_h (2.56GB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" - }, - "sam_vit_l (1.25GB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" - }, - "sam_vit_b (375MB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" - }, - "sam_hq_vit_h (2.57GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" - }, - "sam_hq_vit_l (1.25GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" - }, - "sam_hq_vit_b (379MB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" - }, - "mobile_sam(39MB)": { - "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" - }, -} - -groundingdino_model_dir_name = "grounding-dino" -groundingdino_model_list = { - "GroundingDINO_SwinT_OGC (694MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", - }, - "GroundingDINO_SwinB (938MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", - }, -} - - -def get_bert_base_uncased_model_path(): - comfy_bert_model_base = os.path.join(folder_paths.models_dir, "bert-base-uncased") - if glob.glob( - os.path.join(comfy_bert_model_base, "**/model.safetensors"), recursive=True - ): - print("grounding-dino is using models/bert-base-uncased") - return comfy_bert_model_base - return "bert-base-uncased" - - -def save_masks(outmasks, image): - if len(outmasks) == 0: - return - if len(outmasks.shape) > 3: - outmasks = outmasks.permute(1, 0, 2, 3) - outmasks = ( - outmasks.view(outmasks.shape[1], outmasks.shape[2], outmasks.shape[3]) - .cpu() - 
.numpy() - ) - - image_height, image_width, _ = image.shape - - img = np.zeros((image_height, image_width, 3), dtype=np.uint8) - mask_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - - for mask in outmasks: - non_zero_positions = np.nonzero(mask) - # 使用非零值位置从原始图像中提取实例部分 - instance_part = image[non_zero_positions] - # 将提取的实例部分应用到合并图像中对应的位置 - img[non_zero_positions] = instance_part - - mask_image[non_zero_positions] = 255 # 标记为白色 - return img, mask_image - - -def list_sam_model(): - return list(sam_model_list.keys()) - - -def sam_predict_torch( - predictor, image_np, input_points, input_boxes, input_label, multimask_output -): - - image_np_rgb = image_np[..., :3] - predictor.set_image(image_np_rgb) - - transformed_boxes = predictor.transform.apply_boxes_torch( - input_boxes, image_np.shape[:2] - ) - - masks, scores, logits = predictor.predict_torch( - point_coords=input_points, - point_labels=input_label, - boxes=transformed_boxes, - multimask_output=multimask_output, - ) - - return masks, scores, logits - - -def get_local_filepath(url, dirname, local_file_name=None): - if not local_file_name: - parsed_url = urlparse(url) - local_file_name = os.path.basename(parsed_url.path) - - destination = folder_paths.get_full_path(dirname, local_file_name) - if destination: - return destination - - folder = os.path.join(folder_paths.models_dir, dirname) - if not os.path.exists(folder): - os.makedirs(folder) - - destination = os.path.join(folder, local_file_name) - if not os.path.exists(destination): - download_url_to_file(url, destination) - return destination - - -def load_groundingdino_model(model_name): - dino_model_args = local_groundingdino_SLConfig.fromfile( - get_local_filepath( - groundingdino_model_list[model_name]["config_url"], - groundingdino_model_dir_name, - ), - ) - - if dino_model_args.text_encoder_type == "bert-base-uncased": - dino_model_args.text_encoder_type = get_bert_base_uncased_model_path() - - dino = local_groundingdino_build_model(dino_model_args) - checkpoint = torch.load( - get_local_filepath( - groundingdino_model_list[model_name]["model_url"], - groundingdino_model_dir_name, - ), - ) - dino.load_state_dict( - local_groundingdino_clean_state_dict(checkpoint["model"]), strict=False - ) - device = comfy.model_management.get_torch_device() - dino.to(device=device) - dino.eval() - return dino - - -def list_groundingdino_model(): - return list(groundingdino_model_list.keys()) - - -def load_image(image_pil): - transform = T.Compose( - [ - T.RandomResize([800], max_size=1333), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) - - image_transformed, _ = transform(image_pil, None) - return image_transformed - - -def guided_filter_alpha( - image: torch.Tensor, mask: torch.Tensor, filter_radius: int -) -> torch.Tensor: - sigma = 0.15 - d = filter_radius + 1 - mask = pil2tensor(tensor2pil(mask).convert("RGB")) - if not bool(d % 2): - d += 1 - s = sigma / 10 - i_dup = copy.deepcopy(image.cpu().numpy()) - a_dup = copy.deepcopy(mask.cpu().numpy()) - for index, image in enumerate(i_dup): - alpha_work = a_dup[index] - i_dup[index] = guidedFilter(image, alpha_work, d, s) - return torch.from_numpy(i_dup) - - -def histogram_remap( - image: torch.Tensor, blackpoint: float, whitepoint: float -) -> torch.Tensor: - bp = min(blackpoint, whitepoint - 0.001) - scale = 1 / (whitepoint - bp) - i_dup = copy.deepcopy(image.cpu().numpy()) - i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0) - return torch.from_numpy(i_dup) - - -def mask_edge_detail( - 
image: torch.Tensor, - mask: torch.Tensor, - detail_range: int = 8, - black_point: float = 0.01, - white_point: float = 0.99, -) -> torch.Tensor: - from pymatting import estimate_alpha_cf, fix_trimap - - d = detail_range * 5 + 1 - mask = pil2tensor(tensor2pil(mask).convert("RGB")) - if not bool(d % 2): - d += 1 - i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64)) - a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64)) - for index, img in enumerate(i_dup): - trimap = a_dup[index][:, :, 0] # convert to single channel - if detail_range > 0: - trimap = cv2.GaussianBlur(trimap, (d, d), 0) - trimap = fix_trimap(trimap, black_point, white_point) - alpha = estimate_alpha_cf( - img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} - ) - a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb - return torch.from_numpy(a_dup.astype(np.float32)) - - -def generate_VITMatte_trimap( - mask: torch.Tensor, erode_kernel_size: int, dilate_kernel_size: int -) -> Image: - def g_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10): - erode_kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8) - dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8) - eroded = cv2.erode(mask, erode_kernel, iterations=5) - dilated = cv2.dilate(mask, dilate_kernel, iterations=5) - trimap = np.zeros_like(mask) - trimap[dilated == 255] = 128 - trimap[eroded == 255] = 255 - return trimap - - mask = mask.squeeze(0).cpu().detach().numpy().astype(np.uint8) * 255 - trimap = g_trimap(mask, erode_kernel_size, dilate_kernel_size).astype(np.float32) - trimap[trimap == 128] = 0.5 - trimap[trimap == 255] = 1 - trimap = torch.from_numpy(trimap).unsqueeze(0) - - return tensor2pil(trimap).convert("L") - - -def generate_VITMatte( - vit_matte_model, - vitmatte_predictor, - image: Image, - trimap: Image, - device: str = "cpu", - max_megapixels: float = 2.0, -) -> Image: - if image.mode != "RGB": - image = image.convert("RGB") - if trimap.mode != "L": - trimap = trimap.convert("L") - max_megapixels *= 1048576 - width, height = image.size - ratio = width / height - target_width = math.sqrt(ratio * max_megapixels) - target_height = target_width / ratio - target_width = int(target_width) - target_height = int(target_height) - if width * height > max_megapixels: - image = image.resize((target_width, target_height), Image.BILINEAR) - trimap = trimap.resize((target_width, target_height), Image.BILINEAR) - print( - f"vitmatte image size {width}x{height} too large, resize to {target_width}x{target_height} for processing." - ) - - print( - f"vitmatte processing, image size = {image.width}x{image.height}, device = {device}." 
- ) - inputs = vitmatte_predictor(images=image, trimaps=trimap, return_tensors="pt") - with torch.no_grad(): - inputs = {k: v.to(device) for k, v in inputs.items()} - predictions = vit_matte_model(**inputs).alphas - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - mask = tensor2pil(predictions).convert("L") - mask = mask.crop( - (0, 0, image.width, image.height) - ) # remove padding that the prediction appends (works in 32px tiles) - if width * height > max_megapixels: - mask = mask.resize((width, height), Image.BILINEAR) - return mask - - -def pil2tensor(image: Image) -> torch.Tensor: - return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) - - -def tensor2pil(t_image: torch.Tensor) -> Image: - return Image.fromarray( - np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) - ) - - -def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: - if len(tensor.shape) == 3: # Single image - return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) - else: # Batch of images - return [ - np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor - ] - - -def mask2image(mask: torch.Tensor) -> Image: - masks = tensor2np(mask) - for m in masks: - _mask = Image.fromarray(m).convert("L") - _image = Image.new("RGBA", _mask.size, color="white") - _image = Image.composite( - _image, Image.new("RGBA", _mask.size, color="black"), _mask - ) - return _image - - -def image2mask(image: Image) -> torch.Tensor: - _image = image.convert("RGBA") - alpha = _image.split()[0] - bg = Image.new("L", _image.size) - _image = Image.merge("RGBA", (bg, bg, bg, alpha)) - ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()]) - return ret_mask - - -def RGB2RGBA(image: Image, mask: Image) -> Image: - (R, G, B) = image.convert("RGB").split() - return Image.merge("RGBA", (R, G, B, mask.convert("L"))) - - -def groundingdino_predict(dino_model, image_pil, prompt, box_threshold, text_threshold): - image = load_image(image_pil) - - boxes, logits, phrases = predict( - model=dino_model, - image=image, - caption=prompt, - box_threshold=box_threshold, - text_threshold=text_threshold, - ) - - filt_mask = logits > box_threshold - boxes_filt = boxes.clone() - boxes_filt = boxes_filt[filt_mask] - H, W = image_pil.size[1], image_pil.size[0] - for i in range(boxes_filt.size(0)): - boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) - boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 - boxes_filt[i][2:] += boxes_filt[i][:2] - return boxes_filt From e91c76820c2d4d545023ce4d1ba00c3466148ea6 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Fri, 3 Jan 2025 20:51:54 +0800 Subject: [PATCH 05/13] refine the code --- bizyair_extras/nodes_segment_anything.py | 154 +++- examples/bizyair_segment_anything_ultra.json | 902 +++++++++++++++---- 2 files changed, 847 insertions(+), 209 deletions(-) diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py index 2e565413..9dc2e593 100644 --- a/bizyair_extras/nodes_segment_anything.py +++ b/bizyair_extras/nodes_segment_anything.py @@ -1,37 +1,29 @@ -import os -from pathlib import Path from urllib.parse import urlparse -import comfy.model_management -import folder_paths -import groundingdino.datasets.transforms as T -import numpy as np -import torch - from bizyair import BizyAirBaseNode sam_model_list = { "sam_vit_h (2.56GB)": { "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" }, - "sam_vit_l (1.25GB)": 
{ - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" - }, - "sam_vit_b (375MB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" - }, - "sam_hq_vit_h (2.57GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" - }, - "sam_hq_vit_l (1.25GB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" - }, - "sam_hq_vit_b (379MB)": { - "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" - }, - "mobile_sam(39MB)": { - "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" - }, + # "sam_vit_l (1.25GB)": { + # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" + # }, + # "sam_vit_b (375MB)": { + # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + # }, + # "sam_hq_vit_h (2.57GB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" + # }, + # "sam_hq_vit_l (1.25GB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" + # }, + # "sam_hq_vit_b (379MB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" + # }, + # "mobile_sam(39MB)": { + # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" + # }, } groundingdino_model_dir_name = "grounding-dino" @@ -40,10 +32,10 @@ "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", }, - "GroundingDINO_SwinB (938MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", - }, + # "GroundingDINO_SwinB (938MB)": { + # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", + # "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", + # }, } @@ -204,3 +196,105 @@ def INPUT_TYPES(cls): "mask", ) NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict" + + +class BizyAir_GuidedFilterPredict(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE", {}), + "mask": ("MASK",), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "black_point": ( + "FLOAT", + { + "default": 0.15, + "min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + ) + RETURN_NAMES = ( + "image", + "mask", + ) + NODE_DISPLAY_NAME = "☁️BizyAir GuidedFilter Predict" + + +class BizyAir_PyMattingPredict(BizyAirBaseNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE", {}), + "mask": ("MASK",), + "detail_erode": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "detail_dilate": ( + "INT", + {"default": 6, "min": 1, "max": 255, "step": 1}, + ), + "black_point": ( + "FLOAT", + { + "default": 0.15, + 
"min": 0.01, + "max": 0.98, + "step": 0.01, + "display": "slider", + }, + ), + "white_point": ( + "FLOAT", + { + "default": 0.99, + "min": 0.02, + "max": 0.99, + "step": 0.01, + "display": "slider", + }, + ), + } + } + + CATEGORY = "☁️BizyAir/segment-anything" + # FUNCTION = "main" + RETURN_TYPES = ( + "IMAGE", + "MASK", + ) + RETURN_NAMES = ( + "image", + "mask", + ) + NODE_DISPLAY_NAME = "☁️BizyAir PyMatting Predict" diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json index 947a4538..75b13856 100644 --- a/examples/bizyair_segment_anything_ultra.json +++ b/examples/bizyair_segment_anything_ultra.json @@ -1,100 +1,186 @@ { - "last_node_id": 23, - "last_link_id": 49, + "last_node_id": 40, + "last_link_id": 76, "nodes": [ { - "id": 22, - "type": "BizyAir_TrimapGenerate", + "id": 20, + "type": "BizyAir_GroundingDinoModelLoader", "pos": [ - 691.5542602539062, - 25.333850860595703 + -119.09739685058594, + -87.74108123779297 ], "size": [ - 315, - 82 + 415.8000183105469, + 58 ], "flags": {}, - "order": 7, + "order": 0, "mode": 0, - "inputs": [ + "inputs": [], + "outputs": [ { - "name": "mask", - "type": "MASK", - "link": 42 + "name": "GROUNDING_DINO_MODEL", + "type": "GROUNDING_DINO_MODEL", + "links": [ + 35 + ], + "slot_index": 0 } ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoModelLoader" + }, + "widgets_values": [ + "GroundingDINO_SwinT_OGC (694MB)" + ] + }, + { + "id": 18, + "type": "BizyAir_SAMModelLoader", + "pos": [ + -24.793033599853516, + 40.01811218261719 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], "outputs": [ { - "name": "trimap", - "type": "MASK", + "name": "SAM_PREDICTOR", + "type": "SAM_PREDICTOR", "links": [ - 46, - 49 + 32, + 36 ], "slot_index": 0 } ], "properties": { - "Node name for S&R": "BizyAir_TrimapGenerate" + "Node name for S&R": "BizyAir_SAMModelLoader" }, "widgets_values": [ - 6, - 6 + "sam_vit_h (2.56GB)" ] }, { - "id": 19, - "type": "BizyAir_VITMatteModelLoader", + "id": 6, + "type": "LoadImage", "pos": [ - 668.5579833984375, - -124.51361846923828 + -29.21724510192871, + 167.51760864257812 ], "size": [ - 365.4000244140625, - 78 + 315, + 314 ], "flags": {}, - "order": 0, + "order": 2, "mode": 0, "inputs": [], "outputs": [ { - "name": "VitMatte_MODEL", - "type": "VitMatte_MODEL", + "name": "IMAGE", + "type": "IMAGE", "links": [ - 44 + 38 ], "slot_index": 0 }, { - "name": "VitMatte_predictor", - "type": "VitMatte_predictor", - "links": [ - 45 - ], - "slot_index": 1 + "name": "MASK", + "type": "MASK", + "links": null } ], "properties": { - "Node name for S&R": "BizyAir_VITMatteModelLoader" + "Node name for S&R": "LoadImage" }, "widgets_values": [ - "VITMatte" + "0855.png", + "image" ] }, + { + "id": 24, + "type": "MaskToImage", + "pos": [ + 330.49664306640625, + 165.41400146484375 + ], + "size": [ + 264.5999755859375, + 26 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 51 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 50 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 25, + "type": "PreviewImage", + "pos": [ + 343.9561767578125, + 253.86276245117188 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 50 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" 
+ }, + "widgets_values": [] + }, { "id": 23, "type": "BizyAir_VITMattePredict", "pos": [ - 1073.8350830078125, - -53.61327362060547 + 1220.4493408203125, + 608.5986328125 ], "size": [ 327.5999755859375, 166 ], "flags": {}, - "order": 8, + "order": 10, "mode": 0, "inputs": [ { @@ -131,7 +217,7 @@ "name": "mask", "type": "MASK", "links": [ - 47 + 55 ], "slot_index": 1 } @@ -145,19 +231,149 @@ 2 ] }, + { + "id": 4, + "type": "PreviewImage", + "pos": [ + 352.6379699707031, + 547.1114501953125 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 40 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 19, + "type": "BizyAir_VITMatteModelLoader", + "pos": [ + 785.3675537109375, + 531.9297485351562 + ], + "size": [ + 365.4000244140625, + 78 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VitMatte_MODEL", + "type": "VitMatte_MODEL", + "links": [ + 44 + ], + "slot_index": 0 + }, + { + "name": "VitMatte_predictor", + "type": "VitMatte_predictor", + "links": [ + 45 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VITMatteModelLoader" + }, + "widgets_values": [ + "VITMatte" + ] + }, + { + "id": 28, + "type": "MaskToImage", + "pos": [ + 1612.9088134765625, + 588.6235961914062 + ], + "size": [ + 214.6072540283203, + 26 + ], + "flags": {}, + "order": 18, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 55 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 54 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 29, + "type": "PreviewImage", + "pos": [ + 1620.6007080078125, + 678.995361328125 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 22, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 54 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, { "id": 9, "type": "PreviewImage", "pos": [ - 1457.2244873046875, - -111.08157348632812 + 1872.0692138671875, + 677.0736083984375 ], "size": [ 210, 246 ], "flags": {}, - "order": 10, + "order": 17, "mode": 0, "inputs": [ { @@ -173,116 +389,308 @@ "widgets_values": [] }, { - "id": 20, - "type": "BizyAir_GroundingDinoModelLoader", + "id": 26, + "type": "MaskToImage", "pos": [ - -119.09739685058594, - -87.74108123779297 + 810.7152709960938, + 827.2421875 ], "size": [ - 415.8000183105469, - 58 + 264.5999755859375, + 26 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 53 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 52 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 27, + "type": "PreviewImage", + "pos": [ + 811.5810546875, + 920.4977416992188 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 19, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 52 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 21, + "type": "BizyAir_GroundingDinoSAMSegment", + "pos": [ + 334.36334228515625, + -26.426876068115234 + ], + "size": [ + 286.7333068847656, + 146 + ], + "flags": {}, + "order": 4, + "mode": 0, + 
"inputs": [ + { + "name": "grounding_dino_model", + "type": "GROUNDING_DINO_MODEL", + "link": 35 + }, + { + "name": "sam_predictor", + "type": "SAM_PREDICTOR", + "link": 36 + }, + { + "name": "image", + "type": "IMAGE", + "link": 38 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 40, + 43, + 56, + 72 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": [ + 42, + 51, + 57, + 71 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" + }, + "widgets_values": [ + "girl", + 0.3, + 0.3 + ] + }, + { + "id": 38, + "type": "PreviewImage", + "pos": [ + 1439.570556640625, + -430.55450439453125 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 74 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 22, + "type": "BizyAir_TrimapGenerate", + "pos": [ + 803.268310546875, + 675.9127807617188 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 42 + } ], - "flags": {}, - "order": 1, - "mode": 0, - "inputs": [], "outputs": [ { - "name": "GROUNDING_DINO_MODEL", - "type": "GROUNDING_DINO_MODEL", + "name": "trimap", + "type": "MASK", "links": [ - 35 + 46, + 53 ], "slot_index": 0 } ], "properties": { - "Node name for S&R": "BizyAir_GroundingDinoModelLoader" + "Node name for S&R": "BizyAir_TrimapGenerate" }, "widgets_values": [ - "GroundingDINO_SwinT_OGC (694MB)" + 4, + 6 ] }, { - "id": 18, - "type": "BizyAir_SAMModelLoader", + "id": 30, + "type": "BizyAir_PyMattingPredict", "pos": [ - -24.793033599853516, - 40.01811218261719 + 746.6868286132812, + -9.079998016357422 ], "size": [ - 315, - 58 + 340.20001220703125, + 150 ], "flags": {}, - "order": 2, + "order": 8, "mode": 0, - "inputs": [], + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 56 + }, + { + "name": "mask", + "type": "MASK", + "link": 57 + } + ], "outputs": [ { - "name": "SAM_PREDICTOR", - "type": "SAM_PREDICTOR", + "name": "image", + "type": "IMAGE", "links": [ - 32, - 36 + 60 ], "slot_index": 0 + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 61 + ], + "slot_index": 1 } ], "properties": { - "Node name for S&R": "BizyAir_SAMModelLoader" + "Node name for S&R": "BizyAir_PyMattingPredict" }, "widgets_values": [ - "sam_vit_h (2.56GB)" + 3, + 6, + 0.15, + 0.99 ] }, { - "id": 5, - "type": "MaskPreview+", + "id": 33, + "type": "MaskToImage", "pos": [ - 338.49713134765625, - 236.64303588867188 + 1125.576171875, + -39.265621185302734 ], "size": [ - 210, - 246 + 214.6072540283203, + 26 ], "flags": {}, - "order": 6, + "order": 14, "mode": 0, "inputs": [ { "name": "mask", "type": "MASK", - "link": 41 + "link": 61 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 59 + ], + "slot_index": 0 } ], - "outputs": [], "properties": { - "Node name for S&R": "MaskPreview+" + "Node name for S&R": "MaskToImage" }, "widgets_values": [] }, { - "id": 4, + "id": 31, "type": "PreviewImage", "pos": [ - 585.2964477539062, - 235.61814880371094 + 1133.45654296875, + 48.70017623901367 ], "size": [ 210, 246 ], "flags": {}, - "order": 5, + "order": 20, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", - "link": 40 + "link": 59 } ], "outputs": [], @@ -292,154 +700,146 @@ "widgets_values": [] }, { - "id": 8, - "type": "MaskPreview+", + "id": 34, + "type": 
"PreviewImage", "pos": [ - 852.276611328125, - 234.75250244140625 + 1377.463623046875, + 44.568546295166016 ], "size": [ 210, 246 ], "flags": {}, - "order": 9, + "order": 13, "mode": 0, "inputs": [ { - "name": "mask", - "type": "MASK", - "link": 49 + "name": "images", + "type": "IMAGE", + "link": 60 } ], "outputs": [], "properties": { - "Node name for S&R": "MaskPreview+" + "Node name for S&R": "PreviewImage" }, "widgets_values": [] }, { - "id": 6, - "type": "LoadImage", + "id": 37, + "type": "PreviewImage", "pos": [ - -29.21724510192871, - 167.51760864257812 + 1194.40283203125, + -424.7430114746094 ], "size": [ - 315, - 314 + 210, + 246 ], "flags": {}, - "order": 3, + "order": 21, "mode": 0, - "inputs": [], - "outputs": [ + "inputs": [ { - "name": "IMAGE", + "name": "images", "type": "IMAGE", - "links": [ - 38 - ], - "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": null + "link": 76 } ], + "outputs": [], "properties": { - "Node name for S&R": "LoadImage" + "Node name for S&R": "PreviewImage" }, - "widgets_values": [ - "0809.png", - "image" - ] + "widgets_values": [] }, { - "id": 21, - "type": "BizyAir_GroundingDinoSAMSegment", + "id": 40, + "type": "BizyAir_GuidedFilterPredict", "pos": [ - 334.36334228515625, - -26.426876068115234 + 753.5757446289062, + -427.8138427734375 ], "size": [ - 286.7333068847656, - 146 + 378, + 150 ], "flags": {}, - "order": 4, + "order": 9, "mode": 0, "inputs": [ - { - "name": "grounding_dino_model", - "type": "GROUNDING_DINO_MODEL", - "link": 35 - }, - { - "name": "sam_predictor", - "type": "SAM_PREDICTOR", - "link": 36 - }, { "name": "image", "type": "IMAGE", - "link": 38 + "link": 72 + }, + { + "name": "mask", + "type": "MASK", + "link": 71 } ], "outputs": [ { - "name": "IMAGE", + "name": "image", "type": "IMAGE", "links": [ - 40, - 43 + 74 ], "slot_index": 0 }, { - "name": "MASK", + "name": "mask", "type": "MASK", "links": [ - 41, - 42 + 75 ], "slot_index": 1 } ], "properties": { - "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" + "Node name for S&R": "BizyAir_GuidedFilterPredict" }, "widgets_values": [ - "lion", - 0.3, - 0.3 + 6, + 6, + 0.15, + 0.99 ] }, { - "id": 15, - "type": "MaskPreview+", + "id": 36, + "type": "MaskToImage", "pos": [ - 1448.48388671875, - 234.31948852539062 + 1170.7628173828125, + -506.9857482910156 ], "size": [ - 210, - 246 + 214.6072540283203, + 26 ], "flags": {}, - "order": 11, + "order": 16, "mode": 0, "inputs": [ { "name": "mask", "type": "MASK", - "link": 47 + "link": 75 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 76 + ], + "slot_index": 0 } ], - "outputs": [], "properties": { - "Node name for S&R": "MaskPreview+" + "Node name for S&R": "MaskToImage" }, "widgets_values": [] } @@ -485,14 +885,6 @@ 0, "IMAGE" ], - [ - 41, - 21, - 1, - 5, - 0, - "MASK" - ], [ 42, 21, @@ -534,38 +926,190 @@ "MASK" ], [ - 47, + 48, 23, + 0, + 9, + 0, + "IMAGE" + ], + [ + 50, + 24, + 0, + 25, + 0, + "IMAGE" + ], + [ + 51, + 21, 1, - 15, + 24, 0, "MASK" ], [ - 48, - 23, + 52, + 26, 0, - 9, + 27, 0, "IMAGE" ], [ - 49, + 53, 22, 0, - 8, + 26, + 0, + "MASK" + ], + [ + 54, + 28, + 0, + 29, + 0, + "IMAGE" + ], + [ + 55, + 23, + 1, + 28, + 0, + "MASK" + ], + [ + 56, + 21, + 0, + 30, + 0, + "IMAGE" + ], + [ + 57, + 21, + 1, + 30, + 1, + "MASK" + ], + [ + 59, + 33, + 0, + 31, + 0, + "IMAGE" + ], + [ + 60, + 30, + 0, + 34, + 0, + "IMAGE" + ], + [ + 61, + 30, + 1, + 33, + 0, + "MASK" + ], + [ + 71, + 21, + 1, + 40, + 1, + "MASK" + ], + [ + 72, + 21, + 0, + 40, + 0, + "IMAGE" + ], + [ + 74, + 40, + 
0, + 38, + 0, + "IMAGE" + ], + [ + 75, + 40, + 1, + 36, 0, "MASK" + ], + [ + 76, + 36, + 0, + 37, + 0, + "IMAGE" ] ], - "groups": [], + "groups": [ + { + "id": 1, + "title": "VitMatte", + "bounding": [ + 696.8377685546875, + 384.8053894042969, + 1436.54345703125, + 805.4721069335938 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 2, + "title": "Group", + "bounding": [ + 697.36669921875, + -120.6985092163086, + 914.1190185546875, + 431.1031494140625 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 3, + "title": "Group", + "bounding": [ + 696.3088989257812, + -561.6924438476562, + 1012.4700927734375, + 393.0317077636719 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], "config": {}, "extra": { "ds": { - "scale": 0.42409761837248516, + "scale": 0.5209868481924371, "offset": [ - 259.23479036354416, - 425.1317591221424 + -45.97107468187877, + 655.7454599034186 ] } }, From 8cc8b4fea534e3ac54a928bbdfcc46d32436f1da Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 09:46:02 +0800 Subject: [PATCH 06/13] refine the code --- bizyair_extras/__init__.py | 8 + bizyair_extras/nodes_segment_anything.py | 190 ++++---- examples/bizyair_segment_anything_ultra.json | 430 +++++-------------- 3 files changed, 224 insertions(+), 404 deletions(-) diff --git a/bizyair_extras/__init__.py b/bizyair_extras/__init__.py index 418c4055..b641db39 100644 --- a/bizyair_extras/__init__.py +++ b/bizyair_extras/__init__.py @@ -17,3 +17,11 @@ from .nodes_testing_utils import * from .nodes_ultimatesdupscale import * from .nodes_upscale_model import * + + +def update_mappings(module): + NODE_CLASS_MAPPINGS.update(**module.NODE_CLASS_MAPPINGS) + NODE_DISPLAY_NAME_MAPPINGS.update(**module.NODE_DISPLAY_NAME_MAPPINGS) + + +update_mappings(nodes_segment_anything) diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py index 9dc2e593..9fe91d44 100644 --- a/bizyair_extras/nodes_segment_anything.py +++ b/bizyair_extras/nodes_segment_anything.py @@ -2,49 +2,51 @@ from bizyair import BizyAirBaseNode -sam_model_list = { - "sam_vit_h (2.56GB)": { - "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" - }, - # "sam_vit_l (1.25GB)": { - # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" - # }, - # "sam_vit_b (375MB)": { - # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" - # }, - # "sam_hq_vit_h (2.57GB)": { - # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" - # }, - # "sam_hq_vit_l (1.25GB)": { - # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" - # }, - # "sam_hq_vit_b (379MB)": { - # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" - # }, - # "mobile_sam(39MB)": { - # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" - # }, -} +from .nodes_segment_anything_utils import * -groundingdino_model_dir_name = "grounding-dino" -groundingdino_model_list = { - "GroundingDINO_SwinT_OGC (694MB)": { - "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", - "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", - }, - # "GroundingDINO_SwinB (938MB)": { - # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", - # 
"model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", - # }, -} +# sam_model_list = { +# "sam_vit_h (2.56GB)": { +# "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" +# }, +# # "sam_vit_l (1.25GB)": { +# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" +# # }, +# # "sam_vit_b (375MB)": { +# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" +# # }, +# # "sam_hq_vit_h (2.57GB)": { +# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" +# # }, +# # "sam_hq_vit_l (1.25GB)": { +# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" +# # }, +# # "sam_hq_vit_b (379MB)": { +# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" +# # }, +# # "mobile_sam(39MB)": { +# # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" +# # }, +# } +# groundingdino_model_dir_name = "grounding-dino" +# groundingdino_model_list = { +# "GroundingDINO_SwinT_OGC (694MB)": { +# "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", +# "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", +# }, +# # "GroundingDINO_SwinB (938MB)": { +# # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", +# # "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", +# # }, +# } -def list_sam_model(): - return list(sam_model_list.keys()) +# def list_sam_model(): +# return list(sam_model_list.keys()) -def list_groundingdino_model(): - return list(groundingdino_model_list.keys()) + +# def list_groundingdino_model(): +# return list(groundingdino_model_list.keys()) class BizyAir_SAMModelLoader(BizyAirBaseNode): @@ -198,13 +200,19 @@ def INPUT_TYPES(cls): NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict" -class BizyAir_GuidedFilterPredict(BizyAirBaseNode): +class BizyAir_DetailMethodPredict: @classmethod def INPUT_TYPES(cls): + + method_list = [ + "PyMatting", + "GuidedFilter", + ] return { "required": { "image": ("IMAGE", {}), "mask": ("MASK",), + "detail_method": (method_list,), "detail_erode": ( "INT", {"default": 6, "min": 1, "max": 255, "step": 1}, @@ -237,7 +245,7 @@ def INPUT_TYPES(cls): } CATEGORY = "☁️BizyAir/segment-anything" - # FUNCTION = "main" + FUNCTION = "main" RETURN_TYPES = ( "IMAGE", "MASK", @@ -246,55 +254,59 @@ def INPUT_TYPES(cls): "image", "mask", ) - NODE_DISPLAY_NAME = "☁️BizyAir GuidedFilter Predict" + def main( + self, + image, + mask, + detail_method, + detail_erode, + detail_dilate, + black_point, + white_point, + ): -class BizyAir_PyMattingPredict(BizyAirBaseNode): - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE", {}), - "mask": ("MASK",), - "detail_erode": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "detail_dilate": ( - "INT", - {"default": 6, "min": 1, "max": 255, "step": 1}, - ), - "black_point": ( - "FLOAT", - { - "default": 0.15, - "min": 0.01, - "max": 0.98, - "step": 0.01, - "display": "slider", - }, - ), - "white_point": ( - "FLOAT", - { - "default": 0.99, - "min": 0.02, - "max": 0.99, - "step": 0.01, - "display": "slider", - }, - ), - } - } + ret_images = [] + ret_masks = [] + # device = comfy.model_management.get_torch_device() - CATEGORY = 
"☁️BizyAir/segment-anything" - # FUNCTION = "main" - RETURN_TYPES = ( - "IMAGE", - "MASK", - ) - RETURN_NAMES = ( - "image", - "mask", - ) - NODE_DISPLAY_NAME = "☁️BizyAir PyMatting Predict" + for i in range(image.shape[0]): + img = torch.unsqueeze(image[i], 0) + img = pil2tensor(tensor2pil(img).convert("RGB")) + _image = tensor2pil(img).convert("RGBA") + + detail_range = detail_erode + detail_dilate + if detail_method == "GuidedFilter": + _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) + _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) + + if detail_method == "PyMatting": + _mask = tensor2pil( + mask_edge_detail( + img, mask[i], detail_range // 8 + 1, black_point, white_point + ) + ) + + _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L")) + + ret_images.append(pil2tensor(_image)) + ret_masks.append(image2mask(_mask)) + if len(ret_masks) == 0: + _, height, width, _ = image.size() + empty_mask = torch.zeros( + (1, height, width), dtype=torch.uint8, device="cpu" + ) + return (empty_mask, empty_mask) + + return ( + torch.cat(ret_images, dim=0), + torch.cat(ret_masks, dim=0), + ) + + +NODE_CLASS_MAPPINGS = { + "BizyAir_DetailMethodPredict": BizyAir_DetailMethodPredict, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "BizyAir_DetailMethodPredict": "☁️BizyAir DetailMethod Predict", +} diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json index 75b13856..d0d25134 100644 --- a/examples/bizyair_segment_anything_ultra.json +++ b/examples/bizyair_segment_anything_ultra.json @@ -1,6 +1,6 @@ { - "last_node_id": 40, - "last_link_id": 76, + "last_node_id": 42, + "last_link_id": 82, "nodes": [ { "id": 20, @@ -67,44 +67,6 @@ "sam_vit_h (2.56GB)" ] }, - { - "id": 6, - "type": "LoadImage", - "pos": [ - -29.21724510192871, - 167.51760864257812 - ], - "size": [ - 315, - 314 - ], - "flags": {}, - "order": 2, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 38 - ], - "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": null - } - ], - "properties": { - "Node name for S&R": "LoadImage" - }, - "widgets_values": [ - "0855.png", - "image" - ] - }, { "id": 24, "type": "MaskToImage", @@ -153,7 +115,7 @@ 246 ], "flags": {}, - "order": 12, + "order": 11, "mode": 0, "inputs": [ { @@ -180,7 +142,7 @@ 166 ], "flags": {}, - "order": 10, + "order": 9, "mode": 0, "inputs": [ { @@ -270,7 +232,7 @@ 78 ], "flags": {}, - "order": 3, + "order": 2, "mode": 0, "inputs": [], "outputs": [ @@ -310,7 +272,7 @@ 26 ], "flags": {}, - "order": 18, + "order": 15, "mode": 0, "inputs": [ { @@ -346,7 +308,7 @@ 246 ], "flags": {}, - "order": 22, + "order": 18, "mode": 0, "inputs": [ { @@ -373,7 +335,7 @@ 246 ], "flags": {}, - "order": 17, + "order": 14, "mode": 0, "inputs": [ { @@ -400,7 +362,7 @@ 26 ], "flags": {}, - "order": 11, + "order": 10, "mode": 0, "inputs": [ { @@ -436,7 +398,7 @@ 246 ], "flags": {}, - "order": 19, + "order": 16, "mode": 0, "inputs": [ { @@ -451,97 +413,6 @@ }, "widgets_values": [] }, - { - "id": 21, - "type": "BizyAir_GroundingDinoSAMSegment", - "pos": [ - 334.36334228515625, - -26.426876068115234 - ], - "size": [ - 286.7333068847656, - 146 - ], - "flags": {}, - "order": 4, - "mode": 0, - "inputs": [ - { - "name": "grounding_dino_model", - "type": "GROUNDING_DINO_MODEL", - "link": 35 - }, - { - "name": "sam_predictor", - "type": "SAM_PREDICTOR", - "link": 36 - }, - { - "name": "image", - "type": "IMAGE", - "link": 38 - } - ], - "outputs": [ - { - "name": 
"IMAGE", - "type": "IMAGE", - "links": [ - 40, - 43, - 56, - 72 - ], - "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": [ - 42, - 51, - 57, - 71 - ], - "slot_index": 1 - } - ], - "properties": { - "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" - }, - "widgets_values": [ - "girl", - 0.3, - 0.3 - ] - }, - { - "id": 38, - "type": "PreviewImage", - "pos": [ - 1439.570556640625, - -430.55450439453125 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 15, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 74 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, { "id": 22, "type": "BizyAir_TrimapGenerate", @@ -582,60 +453,6 @@ 6 ] }, - { - "id": 30, - "type": "BizyAir_PyMattingPredict", - "pos": [ - 746.6868286132812, - -9.079998016357422 - ], - "size": [ - 340.20001220703125, - 150 - ], - "flags": {}, - "order": 8, - "mode": 0, - "inputs": [ - { - "name": "image", - "type": "IMAGE", - "link": 56 - }, - { - "name": "mask", - "type": "MASK", - "link": 57 - } - ], - "outputs": [ - { - "name": "image", - "type": "IMAGE", - "links": [ - 60 - ], - "slot_index": 0 - }, - { - "name": "mask", - "type": "MASK", - "links": [ - 61 - ], - "slot_index": 1 - } - ], - "properties": { - "Node name for S&R": "BizyAir_PyMattingPredict" - }, - "widgets_values": [ - 3, - 6, - 0.15, - 0.99 - ] - }, { "id": 33, "type": "MaskToImage", @@ -648,13 +465,13 @@ 26 ], "flags": {}, - "order": 14, + "order": 13, "mode": 0, "inputs": [ { "name": "mask", "type": "MASK", - "link": 61 + "link": 80 } ], "outputs": [ @@ -684,7 +501,7 @@ 246 ], "flags": {}, - "order": 20, + "order": 17, "mode": 0, "inputs": [ { @@ -711,13 +528,13 @@ 246 ], "flags": {}, - "order": 13, + "order": 12, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", - "link": 60 + "link": 79 } ], "outputs": [], @@ -727,121 +544,157 @@ "widgets_values": [] }, { - "id": 37, - "type": "PreviewImage", + "id": 6, + "type": "LoadImage", "pos": [ - 1194.40283203125, - -424.7430114746094 + -29.21724510192871, + 167.51760864257812 ], "size": [ - 210, - 246 + 315, + 314 ], "flags": {}, - "order": 21, + "order": 3, "mode": 0, - "inputs": [ + "inputs": [], + "outputs": [ { - "name": "images", + "name": "IMAGE", "type": "IMAGE", - "link": 76 + "links": [ + 38 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null } ], - "outputs": [], "properties": { - "Node name for S&R": "PreviewImage" + "Node name for S&R": "LoadImage" }, - "widgets_values": [] + "widgets_values": [ + "0854.png", + "image" + ] }, { - "id": 40, - "type": "BizyAir_GuidedFilterPredict", + "id": 21, + "type": "BizyAir_GroundingDinoSAMSegment", "pos": [ - 753.5757446289062, - -427.8138427734375 + 334.36334228515625, + -26.426876068115234 ], "size": [ - 378, - 150 + 286.7333068847656, + 146 ], "flags": {}, - "order": 9, + "order": 4, "mode": 0, "inputs": [ { - "name": "image", - "type": "IMAGE", - "link": 72 + "name": "grounding_dino_model", + "type": "GROUNDING_DINO_MODEL", + "link": 35 }, { - "name": "mask", - "type": "MASK", - "link": 71 + "name": "sam_predictor", + "type": "SAM_PREDICTOR", + "link": 36 + }, + { + "name": "image", + "type": "IMAGE", + "link": 38 } ], "outputs": [ { - "name": "image", + "name": "IMAGE", "type": "IMAGE", "links": [ - 74 + 40, + 43, + 77 ], "slot_index": 0 }, { - "name": "mask", + "name": "MASK", "type": "MASK", "links": [ - 75 + 42, + 51, + 78 ], "slot_index": 1 } ], "properties": { - "Node name for S&R": 
"BizyAir_GuidedFilterPredict" + "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" }, "widgets_values": [ - 6, - 6, - 0.15, - 0.99 + "house", + 0.3, + 0.3 ] }, { - "id": 36, - "type": "MaskToImage", + "id": 42, + "type": "BizyAir_DetailMethodPredict", "pos": [ - 1170.7628173828125, - -506.9857482910156 + 743.06396484375, + 10.351668357849121 ], "size": [ - 214.6072540283203, - 26 + 340.20001220703125, + 174 ], "flags": {}, - "order": 16, + "order": 8, "mode": 0, "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 77 + }, { "name": "mask", "type": "MASK", - "link": 75 + "link": 78 } ], "outputs": [ { - "name": "IMAGE", + "name": "image", "type": "IMAGE", "links": [ - 76 - ], - "slot_index": 0 + 79 + ] + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 80 + ] } ], "properties": { - "Node name for S&R": "MaskToImage" + "Node name for S&R": "BizyAir_DetailMethodPredict" }, - "widgets_values": [] + "widgets_values": [ + "GuidedFilter", + 6, + 6, + 0.15, + 0.99 + ] } ], "links": [ @@ -981,22 +834,6 @@ 0, "MASK" ], - [ - 56, - 21, - 0, - 30, - 0, - "IMAGE" - ], - [ - 57, - 21, - 1, - 30, - 1, - "MASK" - ], [ 59, 33, @@ -1006,60 +843,36 @@ "IMAGE" ], [ - 60, - 30, + 77, + 21, 0, - 34, + 42, 0, "IMAGE" ], [ - 61, - 30, - 1, - 33, - 0, - "MASK" - ], - [ - 71, + 78, 21, 1, - 40, + 42, 1, "MASK" ], [ - 72, - 21, - 0, - 40, - 0, - "IMAGE" - ], - [ - 74, - 40, + 79, + 42, 0, - 38, + 34, 0, "IMAGE" ], [ - 75, - 40, + 80, + 42, 1, - 36, + 33, 0, "MASK" - ], - [ - 76, - 36, - 0, - 37, - 0, - "IMAGE" ] ], "groups": [ @@ -1078,7 +891,7 @@ }, { "id": 2, - "title": "Group", + "title": "DetailMethod", "bounding": [ 697.36669921875, -120.6985092163086, @@ -1088,28 +901,15 @@ "color": "#3f789e", "font_size": 24, "flags": {} - }, - { - "id": 3, - "title": "Group", - "bounding": [ - 696.3088989257812, - -561.6924438476562, - 1012.4700927734375, - 393.0317077636719 - ], - "color": "#3f789e", - "font_size": 24, - "flags": {} } ], "config": {}, "extra": { "ds": { - "scale": 0.5209868481924371, + "scale": 0.6934334949441341, "offset": [ - -45.97107468187877, - 655.7454599034186 + 303.8549245275885, + 224.88420271042048 ] } }, From d784f4f50712696cf809f31c56e6e43b9d4a0c24 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 10:24:29 +0800 Subject: [PATCH 07/13] add code file --- .../nodes_segment_anything_utils.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 bizyair_extras/nodes_segment_anything_utils.py diff --git a/bizyair_extras/nodes_segment_anything_utils.py b/bizyair_extras/nodes_segment_anything_utils.py new file mode 100644 index 00000000..a81dde51 --- /dev/null +++ b/bizyair_extras/nodes_segment_anything_utils.py @@ -0,0 +1,173 @@ +import copy +from typing import List +from urllib.parse import urlparse + +import cv2 +import groundingdino.datasets.transforms as T +import numpy as np +import torch +from PIL import Image + +try: + from cv2.ximgproc import guidedFilter +except ImportError: + # print(e) + print( + f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" + f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." + f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" + ) + +try: + from cv2.ximgproc import guidedFilter +except ImportError: + # print(e) + print( + f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" + f"\nA few nodes cannot works properly, while most nodes are not affected. 
Please REINSTALL package 'opencv-contrib-python'." + f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" + ) + + +sam_model_dir_name = "sams" +sam_model_list = { + "sam_vit_h (2.56GB)": { + "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + }, + # "sam_vit_l (1.25GB)": { + # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" + # }, + # "sam_vit_b (375MB)": { + # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + # }, + # "sam_hq_vit_h (2.57GB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" + # }, + # "sam_hq_vit_l (1.25GB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" + # }, + # "sam_hq_vit_b (379MB)": { + # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" + # }, + # "mobile_sam(39MB)": { + # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" + # }, +} + +groundingdino_model_dir_name = "grounding-dino" +groundingdino_model_list = { + "GroundingDINO_SwinT_OGC (694MB)": { + "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", + "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", + }, + # "GroundingDINO_SwinB (938MB)": { + # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", + # "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", + # }, +} + + +def list_sam_model(): + return list(sam_model_list.keys()) + + +def list_groundingdino_model(): + return list(groundingdino_model_list.keys()) + + +def guided_filter_alpha( + image: torch.Tensor, mask: torch.Tensor, filter_radius: int +) -> torch.Tensor: + sigma = 0.15 + d = filter_radius + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + s = sigma / 10 + i_dup = copy.deepcopy(image.cpu().numpy()) + a_dup = copy.deepcopy(mask.cpu().numpy()) + for index, image in enumerate(i_dup): + alpha_work = a_dup[index] + i_dup[index] = guidedFilter(image, alpha_work, d, s) + return torch.from_numpy(i_dup) + + +def histogram_remap( + image: torch.Tensor, blackpoint: float, whitepoint: float +) -> torch.Tensor: + bp = min(blackpoint, whitepoint - 0.001) + scale = 1 / (whitepoint - bp) + i_dup = copy.deepcopy(image.cpu().numpy()) + i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0) + return torch.from_numpy(i_dup) + + +def mask_edge_detail( + image: torch.Tensor, + mask: torch.Tensor, + detail_range: int = 8, + black_point: float = 0.01, + white_point: float = 0.99, +) -> torch.Tensor: + from pymatting import estimate_alpha_cf, fix_trimap + + d = detail_range * 5 + 1 + mask = pil2tensor(tensor2pil(mask).convert("RGB")) + if not bool(d % 2): + d += 1 + i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64)) + a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64)) + for index, img in enumerate(i_dup): + trimap = a_dup[index][:, :, 0] # convert to single channel + if detail_range > 0: + trimap = cv2.GaussianBlur(trimap, (d, d), 0) + trimap = fix_trimap(trimap, black_point, white_point) + alpha = estimate_alpha_cf( + img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} + ) + a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb + return 
torch.from_numpy(a_dup.astype(np.float32)) + + +def pil2tensor(image: Image) -> torch.Tensor: + return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) + + +def tensor2pil(t_image: torch.Tensor) -> Image: + return Image.fromarray( + np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) + ) + + +def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]: + if len(tensor.shape) == 3: # Single image + return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8) + else: # Batch of images + return [ + np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor + ] + + +def mask2image(mask: torch.Tensor) -> Image: + masks = tensor2np(mask) + for m in masks: + _mask = Image.fromarray(m).convert("L") + _image = Image.new("RGBA", _mask.size, color="white") + _image = Image.composite( + _image, Image.new("RGBA", _mask.size, color="black"), _mask + ) + return _image + + +def image2mask(image: Image) -> torch.Tensor: + _image = image.convert("RGBA") + alpha = _image.split()[0] + bg = Image.new("L", _image.size) + _image = Image.merge("RGBA", (bg, bg, bg, alpha)) + ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()]) + return ret_mask + + +def RGB2RGBA(image: Image, mask: Image) -> Image: + (R, G, B) = image.convert("RGB").split() + return Image.merge("RGBA", (R, G, B, mask.convert("L"))) From c824b047c8afabb220b254fd87adf3db1ae02331 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 15:19:46 +0800 Subject: [PATCH 08/13] refine the code --- bizyair_extras/nodes_segment_anything.py | 48 ------------------- .../nodes_segment_anything_utils.py | 43 ++--------------- 2 files changed, 3 insertions(+), 88 deletions(-) diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py index 9fe91d44..12e28850 100644 --- a/bizyair_extras/nodes_segment_anything.py +++ b/bizyair_extras/nodes_segment_anything.py @@ -4,50 +4,6 @@ from .nodes_segment_anything_utils import * -# sam_model_list = { -# "sam_vit_h (2.56GB)": { -# "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" -# }, -# # "sam_vit_l (1.25GB)": { -# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" -# # }, -# # "sam_vit_b (375MB)": { -# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -# # }, -# # "sam_hq_vit_h (2.57GB)": { -# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth" -# # }, -# # "sam_hq_vit_l (1.25GB)": { -# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth" -# # }, -# # "sam_hq_vit_b (379MB)": { -# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth" -# # }, -# # "mobile_sam(39MB)": { -# # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt" -# # }, -# } - -# groundingdino_model_dir_name = "grounding-dino" -# groundingdino_model_list = { -# "GroundingDINO_SwinT_OGC (694MB)": { -# "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py", -# "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth", -# }, -# # "GroundingDINO_SwinB (938MB)": { -# # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py", -# # "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth", -# # }, -# 
} - - -# def list_sam_model(): -# return list(sam_model_list.keys()) - - -# def list_groundingdino_model(): -# return list(groundingdino_model_list.keys()) - class BizyAir_SAMModelLoader(BizyAirBaseNode): @classmethod @@ -206,7 +162,6 @@ def INPUT_TYPES(cls): method_list = [ "PyMatting", - "GuidedFilter", ] return { "required": { @@ -276,9 +231,6 @@ def main( _image = tensor2pil(img).convert("RGBA") detail_range = detail_erode + detail_dilate - if detail_method == "GuidedFilter": - _mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1) - _mask = tensor2pil(histogram_remap(_mask, black_point, white_point)) if detail_method == "PyMatting": _mask = tensor2pil( diff --git a/bizyair_extras/nodes_segment_anything_utils.py b/bizyair_extras/nodes_segment_anything_utils.py index a81dde51..79ae8f16 100644 --- a/bizyair_extras/nodes_segment_anything_utils.py +++ b/bizyair_extras/nodes_segment_anything_utils.py @@ -2,32 +2,11 @@ from typing import List from urllib.parse import urlparse -import cv2 import groundingdino.datasets.transforms as T import numpy as np import torch from PIL import Image - -try: - from cv2.ximgproc import guidedFilter -except ImportError: - # print(e) - print( - f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" - f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." - f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" - ) - -try: - from cv2.ximgproc import guidedFilter -except ImportError: - # print(e) - print( - f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'" - f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'." - f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m" - ) - +from scipy.ndimage import convolve, gaussian_filter sam_model_dir_name = "sams" sam_model_list = { @@ -75,23 +54,6 @@ def list_groundingdino_model(): return list(groundingdino_model_list.keys()) -def guided_filter_alpha( - image: torch.Tensor, mask: torch.Tensor, filter_radius: int -) -> torch.Tensor: - sigma = 0.15 - d = filter_radius + 1 - mask = pil2tensor(tensor2pil(mask).convert("RGB")) - if not bool(d % 2): - d += 1 - s = sigma / 10 - i_dup = copy.deepcopy(image.cpu().numpy()) - a_dup = copy.deepcopy(mask.cpu().numpy()) - for index, image in enumerate(i_dup): - alpha_work = a_dup[index] - i_dup[index] = guidedFilter(image, alpha_work, d, s) - return torch.from_numpy(i_dup) - - def histogram_remap( image: torch.Tensor, blackpoint: float, whitepoint: float ) -> torch.Tensor: @@ -120,7 +82,8 @@ def mask_edge_detail( for index, img in enumerate(i_dup): trimap = a_dup[index][:, :, 0] # convert to single channel if detail_range > 0: - trimap = cv2.GaussianBlur(trimap, (d, d), 0) + # trimap = cv2.GaussianBlur(trimap, (d, d), 0) + trimap = gaussian_filter(trimap, sigma=d / 2) trimap = fix_trimap(trimap, black_point, white_point) alpha = estimate_alpha_cf( img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500} From b1c975addb276117886bcb1beca8b7460f448c87 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 15:30:24 +0800 Subject: [PATCH 09/13] refine the code --- bizyair_extras/nodes_segment_anything_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bizyair_extras/nodes_segment_anything_utils.py b/bizyair_extras/nodes_segment_anything_utils.py index 79ae8f16..35dda6d2 100644 
--- a/bizyair_extras/nodes_segment_anything_utils.py +++ b/bizyair_extras/nodes_segment_anything_utils.py @@ -1,12 +1,10 @@ import copy from typing import List -from urllib.parse import urlparse -import groundingdino.datasets.transforms as T import numpy as np import torch from PIL import Image -from scipy.ndimage import convolve, gaussian_filter +from scipy.ndimage import gaussian_filter sam_model_dir_name = "sams" sam_model_list = { From 74958af3c5ae4928f1ea53fbeb39effa5c976829 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 16:23:39 +0800 Subject: [PATCH 10/13] refine the code --- examples/bizyair_segment_anything_ultra.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json index d0d25134..4f04e489 100644 --- a/examples/bizyair_segment_anything_ultra.json +++ b/examples/bizyair_segment_anything_ultra.json @@ -689,7 +689,7 @@ "Node name for S&R": "BizyAir_DetailMethodPredict" }, "widgets_values": [ - "GuidedFilter", + "PyMatting", 6, 6, 0.15, @@ -906,11 +906,11 @@ "config": {}, "extra": { "ds": { - "scale": 0.6934334949441341, - "offset": [ - 303.8549245275885, - 224.88420271042048 - ] + "scale": 0.6209213230591553, + "offset": { + "0": -12.436553955078125, + "1": 273.1443786621094 + } } }, "version": 0.4 From 58dfed9a4a159eb415233d8291042e1b2f650349 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 18:13:40 +0800 Subject: [PATCH 11/13] test --- examples/bizyair_segment_anything_ultra.json | 917 ------------------- 1 file changed, 917 deletions(-) delete mode 100644 examples/bizyair_segment_anything_ultra.json diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json deleted file mode 100644 index 4f04e489..00000000 --- a/examples/bizyair_segment_anything_ultra.json +++ /dev/null @@ -1,917 +0,0 @@ -{ - "last_node_id": 42, - "last_link_id": 82, - "nodes": [ - { - "id": 20, - "type": "BizyAir_GroundingDinoModelLoader", - "pos": [ - -119.09739685058594, - -87.74108123779297 - ], - "size": [ - 415.8000183105469, - 58 - ], - "flags": {}, - "order": 0, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "GROUNDING_DINO_MODEL", - "type": "GROUNDING_DINO_MODEL", - "links": [ - 35 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "BizyAir_GroundingDinoModelLoader" - }, - "widgets_values": [ - "GroundingDINO_SwinT_OGC (694MB)" - ] - }, - { - "id": 18, - "type": "BizyAir_SAMModelLoader", - "pos": [ - -24.793033599853516, - 40.01811218261719 - ], - "size": [ - 315, - 58 - ], - "flags": {}, - "order": 1, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "SAM_PREDICTOR", - "type": "SAM_PREDICTOR", - "links": [ - 32, - 36 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "BizyAir_SAMModelLoader" - }, - "widgets_values": [ - "sam_vit_h (2.56GB)" - ] - }, - { - "id": 24, - "type": "MaskToImage", - "pos": [ - 330.49664306640625, - 165.41400146484375 - ], - "size": [ - 264.5999755859375, - 26 - ], - "flags": {}, - "order": 7, - "mode": 0, - "inputs": [ - { - "name": "mask", - "type": "MASK", - "link": 51 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 50 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "MaskToImage" - }, - "widgets_values": [] - }, - { - "id": 25, - "type": "PreviewImage", - "pos": [ - 343.9561767578125, - 253.86276245117188 - ], - "size": 
[ - 210, - 246 - ], - "flags": {}, - "order": 11, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 50 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 23, - "type": "BizyAir_VITMattePredict", - "pos": [ - 1220.4493408203125, - 608.5986328125 - ], - "size": [ - 327.5999755859375, - 166 - ], - "flags": {}, - "order": 9, - "mode": 0, - "inputs": [ - { - "name": "image", - "type": "IMAGE", - "link": 43 - }, - { - "name": "trimap", - "type": "MASK", - "link": 46 - }, - { - "name": "vitmatte_model", - "type": "VitMatte_MODEL", - "link": 44 - }, - { - "name": "vitmatte_predictor", - "type": "VitMatte_predictor", - "link": 45 - } - ], - "outputs": [ - { - "name": "image", - "type": "IMAGE", - "links": [ - 48 - ], - "slot_index": 0 - }, - { - "name": "mask", - "type": "MASK", - "links": [ - 55 - ], - "slot_index": 1 - } - ], - "properties": { - "Node name for S&R": "BizyAir_VITMattePredict" - }, - "widgets_values": [ - 0.15, - 0.99, - 2 - ] - }, - { - "id": 4, - "type": "PreviewImage", - "pos": [ - 352.6379699707031, - 547.1114501953125 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 5, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 40 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 19, - "type": "BizyAir_VITMatteModelLoader", - "pos": [ - 785.3675537109375, - 531.9297485351562 - ], - "size": [ - 365.4000244140625, - 78 - ], - "flags": {}, - "order": 2, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "VitMatte_MODEL", - "type": "VitMatte_MODEL", - "links": [ - 44 - ], - "slot_index": 0 - }, - { - "name": "VitMatte_predictor", - "type": "VitMatte_predictor", - "links": [ - 45 - ], - "slot_index": 1 - } - ], - "properties": { - "Node name for S&R": "BizyAir_VITMatteModelLoader" - }, - "widgets_values": [ - "VITMatte" - ] - }, - { - "id": 28, - "type": "MaskToImage", - "pos": [ - 1612.9088134765625, - 588.6235961914062 - ], - "size": [ - 214.6072540283203, - 26 - ], - "flags": {}, - "order": 15, - "mode": 0, - "inputs": [ - { - "name": "mask", - "type": "MASK", - "link": 55 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 54 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "MaskToImage" - }, - "widgets_values": [] - }, - { - "id": 29, - "type": "PreviewImage", - "pos": [ - 1620.6007080078125, - 678.995361328125 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 18, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 54 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 9, - "type": "PreviewImage", - "pos": [ - 1872.0692138671875, - 677.0736083984375 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 14, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 48 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 26, - "type": "MaskToImage", - "pos": [ - 810.7152709960938, - 827.2421875 - ], - "size": [ - 264.5999755859375, - 26 - ], - "flags": {}, - "order": 10, - "mode": 0, - "inputs": [ - { - "name": "mask", - "type": "MASK", - "link": 53 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 52 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for 
S&R": "MaskToImage" - }, - "widgets_values": [] - }, - { - "id": 27, - "type": "PreviewImage", - "pos": [ - 811.5810546875, - 920.4977416992188 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 16, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 52 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 22, - "type": "BizyAir_TrimapGenerate", - "pos": [ - 803.268310546875, - 675.9127807617188 - ], - "size": [ - 315, - 82 - ], - "flags": {}, - "order": 6, - "mode": 0, - "inputs": [ - { - "name": "mask", - "type": "MASK", - "link": 42 - } - ], - "outputs": [ - { - "name": "trimap", - "type": "MASK", - "links": [ - 46, - 53 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "BizyAir_TrimapGenerate" - }, - "widgets_values": [ - 4, - 6 - ] - }, - { - "id": 33, - "type": "MaskToImage", - "pos": [ - 1125.576171875, - -39.265621185302734 - ], - "size": [ - 214.6072540283203, - 26 - ], - "flags": {}, - "order": 13, - "mode": 0, - "inputs": [ - { - "name": "mask", - "type": "MASK", - "link": 80 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 59 - ], - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "MaskToImage" - }, - "widgets_values": [] - }, - { - "id": 31, - "type": "PreviewImage", - "pos": [ - 1133.45654296875, - 48.70017623901367 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 17, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 59 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 34, - "type": "PreviewImage", - "pos": [ - 1377.463623046875, - 44.568546295166016 - ], - "size": [ - 210, - 246 - ], - "flags": {}, - "order": 12, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 79 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewImage" - }, - "widgets_values": [] - }, - { - "id": 6, - "type": "LoadImage", - "pos": [ - -29.21724510192871, - 167.51760864257812 - ], - "size": [ - 315, - 314 - ], - "flags": {}, - "order": 3, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 38 - ], - "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": null - } - ], - "properties": { - "Node name for S&R": "LoadImage" - }, - "widgets_values": [ - "0854.png", - "image" - ] - }, - { - "id": 21, - "type": "BizyAir_GroundingDinoSAMSegment", - "pos": [ - 334.36334228515625, - -26.426876068115234 - ], - "size": [ - 286.7333068847656, - 146 - ], - "flags": {}, - "order": 4, - "mode": 0, - "inputs": [ - { - "name": "grounding_dino_model", - "type": "GROUNDING_DINO_MODEL", - "link": 35 - }, - { - "name": "sam_predictor", - "type": "SAM_PREDICTOR", - "link": 36 - }, - { - "name": "image", - "type": "IMAGE", - "link": 38 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 40, - 43, - 77 - ], - "slot_index": 0 - }, - { - "name": "MASK", - "type": "MASK", - "links": [ - 42, - 51, - 78 - ], - "slot_index": 1 - } - ], - "properties": { - "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" - }, - "widgets_values": [ - "house", - 0.3, - 0.3 - ] - }, - { - "id": 42, - "type": "BizyAir_DetailMethodPredict", - "pos": [ - 743.06396484375, - 10.351668357849121 - ], - "size": [ - 340.20001220703125, - 174 - ], - "flags": {}, - "order": 8, - "mode": 0, - "inputs": [ - { - "name": 
"image", - "type": "IMAGE", - "link": 77 - }, - { - "name": "mask", - "type": "MASK", - "link": 78 - } - ], - "outputs": [ - { - "name": "image", - "type": "IMAGE", - "links": [ - 79 - ] - }, - { - "name": "mask", - "type": "MASK", - "links": [ - 80 - ] - } - ], - "properties": { - "Node name for S&R": "BizyAir_DetailMethodPredict" - }, - "widgets_values": [ - "PyMatting", - 6, - 6, - 0.15, - 0.99 - ] - } - ], - "links": [ - [ - 32, - 18, - 0, - 17, - 1, - "SAM_PREDICTOR" - ], - [ - 35, - 20, - 0, - 21, - 0, - "GROUNDING_DINO_MODEL" - ], - [ - 36, - 18, - 0, - 21, - 1, - "SAM_PREDICTOR" - ], - [ - 38, - 6, - 0, - 21, - 2, - "IMAGE" - ], - [ - 40, - 21, - 0, - 4, - 0, - "IMAGE" - ], - [ - 42, - 21, - 1, - 22, - 0, - "MASK" - ], - [ - 43, - 21, - 0, - 23, - 0, - "IMAGE" - ], - [ - 44, - 19, - 0, - 23, - 2, - "VitMatte_MODEL" - ], - [ - 45, - 19, - 1, - 23, - 3, - "VitMatte_predictor" - ], - [ - 46, - 22, - 0, - 23, - 1, - "MASK" - ], - [ - 48, - 23, - 0, - 9, - 0, - "IMAGE" - ], - [ - 50, - 24, - 0, - 25, - 0, - "IMAGE" - ], - [ - 51, - 21, - 1, - 24, - 0, - "MASK" - ], - [ - 52, - 26, - 0, - 27, - 0, - "IMAGE" - ], - [ - 53, - 22, - 0, - 26, - 0, - "MASK" - ], - [ - 54, - 28, - 0, - 29, - 0, - "IMAGE" - ], - [ - 55, - 23, - 1, - 28, - 0, - "MASK" - ], - [ - 59, - 33, - 0, - 31, - 0, - "IMAGE" - ], - [ - 77, - 21, - 0, - 42, - 0, - "IMAGE" - ], - [ - 78, - 21, - 1, - 42, - 1, - "MASK" - ], - [ - 79, - 42, - 0, - 34, - 0, - "IMAGE" - ], - [ - 80, - 42, - 1, - 33, - 0, - "MASK" - ] - ], - "groups": [ - { - "id": 1, - "title": "VitMatte", - "bounding": [ - 696.8377685546875, - 384.8053894042969, - 1436.54345703125, - 805.4721069335938 - ], - "color": "#3f789e", - "font_size": 24, - "flags": {} - }, - { - "id": 2, - "title": "DetailMethod", - "bounding": [ - 697.36669921875, - -120.6985092163086, - 914.1190185546875, - 431.1031494140625 - ], - "color": "#3f789e", - "font_size": 24, - "flags": {} - } - ], - "config": {}, - "extra": { - "ds": { - "scale": 0.6209213230591553, - "offset": { - "0": -12.436553955078125, - "1": 273.1443786621094 - } - } - }, - "version": 0.4 -} From 92e80d37ea268166eadbd62a24f84f8a17c6b7e3 Mon Sep 17 00:00:00 2001 From: Wanghanying <2310016173@qq.com> Date: Mon, 6 Jan 2025 21:48:36 +0800 Subject: [PATCH 12/13] refine the code --- bizyair_extras/nodes_segment_anything.py | 6 +- examples/bizyair_segment_anything_ultra.json | 917 +++++++++++++++++++ 2 files changed, 920 insertions(+), 3 deletions(-) create mode 100644 examples/bizyair_segment_anything_ultra.json diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py index 12e28850..b44f6fba 100644 --- a/bizyair_extras/nodes_segment_anything.py +++ b/bizyair_extras/nodes_segment_anything.py @@ -156,7 +156,7 @@ def INPUT_TYPES(cls): NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict" -class BizyAir_DetailMethodPredict: +class BizyAirDetailMethodPredict: @classmethod def INPUT_TYPES(cls): @@ -257,8 +257,8 @@ def main( NODE_CLASS_MAPPINGS = { - "BizyAir_DetailMethodPredict": BizyAir_DetailMethodPredict, + "BizyAirDetailMethodPredict": BizyAirDetailMethodPredict, } NODE_DISPLAY_NAME_MAPPINGS = { - "BizyAir_DetailMethodPredict": "☁️BizyAir DetailMethod Predict", + "BizyAirDetailMethodPredict": "☁️BizyAir DetailMethod Predict", } diff --git a/examples/bizyair_segment_anything_ultra.json b/examples/bizyair_segment_anything_ultra.json new file mode 100644 index 00000000..34090040 --- /dev/null +++ b/examples/bizyair_segment_anything_ultra.json @@ -0,0 +1,917 @@ +{ + "last_node_id": 43, + 
"last_link_id": 86, + "nodes": [ + { + "id": 20, + "type": "BizyAir_GroundingDinoModelLoader", + "pos": [ + -119.09739685058594, + -87.74108123779297 + ], + "size": [ + 415.8000183105469, + 58 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "GROUNDING_DINO_MODEL", + "type": "GROUNDING_DINO_MODEL", + "links": [ + 35 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoModelLoader" + }, + "widgets_values": [ + "GroundingDINO_SwinT_OGC (694MB)" + ] + }, + { + "id": 18, + "type": "BizyAir_SAMModelLoader", + "pos": [ + -24.793033599853516, + 40.01811218261719 + ], + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "SAM_PREDICTOR", + "type": "SAM_PREDICTOR", + "links": [ + 32, + 36 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_SAMModelLoader" + }, + "widgets_values": [ + "sam_vit_h (2.56GB)" + ] + }, + { + "id": 24, + "type": "MaskToImage", + "pos": [ + 330.49664306640625, + 165.41400146484375 + ], + "size": [ + 264.5999755859375, + 26 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 51 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 50 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 25, + "type": "PreviewImage", + "pos": [ + 343.9561767578125, + 253.86276245117188 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 50 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 23, + "type": "BizyAir_VITMattePredict", + "pos": [ + 1220.4493408203125, + 608.5986328125 + ], + "size": [ + 327.5999755859375, + 166 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 43 + }, + { + "name": "trimap", + "type": "MASK", + "link": 46 + }, + { + "name": "vitmatte_model", + "type": "VitMatte_MODEL", + "link": 44 + }, + { + "name": "vitmatte_predictor", + "type": "VitMatte_predictor", + "link": 45 + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 48 + ], + "slot_index": 0 + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 55 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VITMattePredict" + }, + "widgets_values": [ + 0.15, + 0.99, + 2 + ] + }, + { + "id": 4, + "type": "PreviewImage", + "pos": [ + 352.6379699707031, + 547.1114501953125 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 40 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 19, + "type": "BizyAir_VITMatteModelLoader", + "pos": [ + 785.3675537109375, + 531.9297485351562 + ], + "size": [ + 365.4000244140625, + 78 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VitMatte_MODEL", + "type": "VitMatte_MODEL", + "links": [ + 44 + ], + "slot_index": 0 + }, + { + "name": "VitMatte_predictor", + "type": "VitMatte_predictor", + "links": [ + 45 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_VITMatteModelLoader" + }, + "widgets_values": [ + "VITMatte" + ] + }, + { + 
"id": 28, + "type": "MaskToImage", + "pos": [ + 1612.9088134765625, + 588.6235961914062 + ], + "size": [ + 214.6072540283203, + 26 + ], + "flags": {}, + "order": 15, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 55 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 54 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 29, + "type": "PreviewImage", + "pos": [ + 1620.6007080078125, + 678.995361328125 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 18, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 54 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 9, + "type": "PreviewImage", + "pos": [ + 1872.0692138671875, + 677.0736083984375 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 48 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 26, + "type": "MaskToImage", + "pos": [ + 810.7152709960938, + 827.2421875 + ], + "size": [ + 264.5999755859375, + 26 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 53 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 52 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 27, + "type": "PreviewImage", + "pos": [ + 811.5810546875, + 920.4977416992188 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 16, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 52 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 22, + "type": "BizyAir_TrimapGenerate", + "pos": [ + 803.268310546875, + 675.9127807617188 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 42 + } + ], + "outputs": [ + { + "name": "trimap", + "type": "MASK", + "links": [ + 46, + 53 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "BizyAir_TrimapGenerate" + }, + "widgets_values": [ + 4, + 6 + ] + }, + { + "id": 33, + "type": "MaskToImage", + "pos": [ + 1125.576171875, + -39.265621185302734 + ], + "size": [ + 214.6072540283203, + 26 + ], + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 86 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 59 + ], + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "MaskToImage" + }, + "widgets_values": [] + }, + { + "id": 31, + "type": "PreviewImage", + "pos": [ + 1133.45654296875, + 48.70017623901367 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 17, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 59 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 34, + "type": "PreviewImage", + "pos": [ + 1377.463623046875, + 44.568546295166016 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 85 + } + ], + "outputs": [], + "properties": { + "Node 
name for S&R": "PreviewImage" + }, + "widgets_values": [] + }, + { + "id": 6, + "type": "LoadImage", + "pos": [ + -29.21724510192871, + 167.51760864257812 + ], + "size": [ + 315, + 314 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 38 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "0854.png", + "image" + ] + }, + { + "id": 21, + "type": "BizyAir_GroundingDinoSAMSegment", + "pos": [ + 334.36334228515625, + -26.426876068115234 + ], + "size": [ + 286.7333068847656, + 146 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "grounding_dino_model", + "type": "GROUNDING_DINO_MODEL", + "link": 35 + }, + { + "name": "sam_predictor", + "type": "SAM_PREDICTOR", + "link": 36 + }, + { + "name": "image", + "type": "IMAGE", + "link": 38 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 40, + 43, + 83 + ], + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": [ + 42, + 51, + 84 + ], + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "BizyAir_GroundingDinoSAMSegment" + }, + "widgets_values": [ + "house", + 0.3, + 0.3 + ] + }, + { + "id": 43, + "type": "BizyAirDetailMethodPredict", + "pos": [ + 733.4598999023438, + -11.248224258422852 + ], + "size": [ + 327.5999755859375, + 174 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 83 + }, + { + "name": "mask", + "type": "MASK", + "link": 84 + } + ], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 85 + ] + }, + { + "name": "mask", + "type": "MASK", + "links": [ + 86 + ] + } + ], + "properties": { + "Node name for S&R": "BizyAirDetailMethodPredict" + }, + "widgets_values": [ + "PyMatting", + 6, + 6, + 0.15, + 0.99 + ] + } + ], + "links": [ + [ + 32, + 18, + 0, + 17, + 1, + "SAM_PREDICTOR" + ], + [ + 35, + 20, + 0, + 21, + 0, + "GROUNDING_DINO_MODEL" + ], + [ + 36, + 18, + 0, + 21, + 1, + "SAM_PREDICTOR" + ], + [ + 38, + 6, + 0, + 21, + 2, + "IMAGE" + ], + [ + 40, + 21, + 0, + 4, + 0, + "IMAGE" + ], + [ + 42, + 21, + 1, + 22, + 0, + "MASK" + ], + [ + 43, + 21, + 0, + 23, + 0, + "IMAGE" + ], + [ + 44, + 19, + 0, + 23, + 2, + "VitMatte_MODEL" + ], + [ + 45, + 19, + 1, + 23, + 3, + "VitMatte_predictor" + ], + [ + 46, + 22, + 0, + 23, + 1, + "MASK" + ], + [ + 48, + 23, + 0, + 9, + 0, + "IMAGE" + ], + [ + 50, + 24, + 0, + 25, + 0, + "IMAGE" + ], + [ + 51, + 21, + 1, + 24, + 0, + "MASK" + ], + [ + 52, + 26, + 0, + 27, + 0, + "IMAGE" + ], + [ + 53, + 22, + 0, + 26, + 0, + "MASK" + ], + [ + 54, + 28, + 0, + 29, + 0, + "IMAGE" + ], + [ + 55, + 23, + 1, + 28, + 0, + "MASK" + ], + [ + 59, + 33, + 0, + 31, + 0, + "IMAGE" + ], + [ + 83, + 21, + 0, + 43, + 0, + "IMAGE" + ], + [ + 84, + 21, + 1, + 43, + 1, + "MASK" + ], + [ + 85, + 43, + 0, + 34, + 0, + "IMAGE" + ], + [ + 86, + 43, + 1, + 33, + 0, + "MASK" + ] + ], + "groups": [ + { + "id": 1, + "title": "VitMatte", + "bounding": [ + 696.8377685546875, + 384.8053894042969, + 1436.54345703125, + 805.4721069335938 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 2, + "title": "DetailMethod", + "bounding": [ + 697.36669921875, + -120.6985092163086, + 914.1190185546875, + 431.1031494140625 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], + "config": {}, + "extra": { + "ds": { + "scale": 
0.7404627289582754, + "offset": [ + 23.72436698510137, + 379.85474857340284 + ] + } + }, + "version": 0.4 +} From 437dcf82c442f73dac35c3d16aa6615337028573 Mon Sep 17 00:00:00 2001 From: FengWen Date: Wed, 8 Jan 2025 10:47:10 +0800 Subject: [PATCH 13/13] sam_node_mount_route --- bizyair_extras/__init__.py | 8 -------- bizyair_extras/nodes_segment_anything.py | 14 +++----------- .../commands/processors/prompt_processor.py | 4 +++- src/bizyair/configs/models.yaml | 5 +++++ 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/bizyair_extras/__init__.py b/bizyair_extras/__init__.py index b641db39..418c4055 100644 --- a/bizyair_extras/__init__.py +++ b/bizyair_extras/__init__.py @@ -17,11 +17,3 @@ from .nodes_testing_utils import * from .nodes_ultimatesdupscale import * from .nodes_upscale_model import * - - -def update_mappings(module): - NODE_CLASS_MAPPINGS.update(**module.NODE_CLASS_MAPPINGS) - NODE_DISPLAY_NAME_MAPPINGS.update(**module.NODE_DISPLAY_NAME_MAPPINGS) - - -update_mappings(nodes_segment_anything) diff --git a/bizyair_extras/nodes_segment_anything.py b/bizyair_extras/nodes_segment_anything.py index b44f6fba..9d5f2048 100644 --- a/bizyair_extras/nodes_segment_anything.py +++ b/bizyair_extras/nodes_segment_anything.py @@ -1,5 +1,3 @@ -from urllib.parse import urlparse - from bizyair import BizyAirBaseNode from .nodes_segment_anything_utils import * @@ -156,7 +154,9 @@ def INPUT_TYPES(cls): NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict" -class BizyAirDetailMethodPredict: +class BizyAirDetailMethodPredict(BizyAirBaseNode): + NODE_DISPLAY_NAME = "☁️BizyAir DetailMethod Predict" + @classmethod def INPUT_TYPES(cls): @@ -254,11 +254,3 @@ def main( torch.cat(ret_images, dim=0), torch.cat(ret_masks, dim=0), ) - - -NODE_CLASS_MAPPINGS = { - "BizyAirDetailMethodPredict": BizyAirDetailMethodPredict, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "BizyAirDetailMethodPredict": "☁️BizyAir DetailMethod Predict", -} diff --git a/src/bizyair/commands/processors/prompt_processor.py b/src/bizyair/commands/processors/prompt_processor.py index 3d6a248f..ec6020bc 100644 --- a/src/bizyair/commands/processors/prompt_processor.py +++ b/src/bizyair/commands/processors/prompt_processor.py @@ -77,7 +77,9 @@ def process(self, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]): if rule.base_model == base_model: if rule.score > out_score: out_route, out_score = rule.route, rule.score - + assert ( + out_route is not None + ), "Failed to find out_route, please check your prompt" return f"{BIZYAIR_SERVER_ADDRESS}{out_route}" def validate_input( diff --git a/src/bizyair/configs/models.yaml b/src/bizyair/configs/models.yaml index efede41f..9bf10825 100644 --- a/src/bizyair/configs/models.yaml +++ b/src/bizyair/configs/models.yaml @@ -274,3 +274,8 @@ model_rules: route: /supernode/bizyair-sam nodes: - class_type: 'LayerMask: SegmentAnythingUltra V2' + - class_type: SAMModelLoader + - class_type: TrimapGenerate + - class_type: VITMatteModelLoader + - class_type: DetailMethodPredict + - class_type: VitMattePredict
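
For orientation, the PyMatting branch that PATCH 08 leaves in BizyAirDetailMethodPredict.main ultimately runs the mask_edge_detail helper from nodes_segment_anything_utils.py: blur the hard SAM mask into a trimap, clamp it with fix_trimap, then solve for alpha with closed-form matting. The standalone sketch below reproduces that flow on plain NumPy arrays so it can be tried outside ComfyUI; the function name refine_mask_pymatting and its default argument values are illustrative only and do not appear in the patches.

    import numpy as np
    from pymatting import estimate_alpha_cf, fix_trimap
    from scipy.ndimage import gaussian_filter

    def refine_mask_pymatting(
        image: np.ndarray,        # float64, HxWx3, values in [0, 1]
        mask: np.ndarray,         # float64, HxW, coarse SAM mask in [0, 1]
        detail_range: int = 2,    # (detail_erode + detail_dilate) // 8 + 1 in the node; 2 for the defaults
        black_point: float = 0.15,
        white_point: float = 0.99,
    ) -> np.ndarray:
        """Illustrative re-implementation of the mask_edge_detail flow."""
        # Blur the hard mask so a band of "unknown" pixels appears around the edge.
        d = detail_range * 5 + 1
        if d % 2 == 0:
            d += 1
        trimap = gaussian_filter(mask, sigma=d / 2)
        # Snap near-0 / near-1 pixels back to pure background / foreground,
        # leaving only the blurred edge band marked as unknown (0.5).
        trimap = fix_trimap(trimap, black_point, white_point)
        # Closed-form matting solves for alpha inside the unknown band.
        alpha = estimate_alpha_cf(
            image,
            trimap,
            laplacian_kwargs={"epsilon": 1e-6},
            cg_kwargs={"maxiter": 500},
        )
        return alpha.astype(np.float32)

In the node itself the same steps run per batch item on tensors converted through tensor2pil/pil2tensor, and the resulting alpha is merged back into an RGBA image with RGB2RGBA before being returned as the refined image/mask pair.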