Separate SegmentAnythingUltra V2 into nodes #291

Merged: 15 commits, Jan 8, 2025
1 change: 1 addition & 0 deletions bizyair_extras/__init__.py
@@ -13,6 +13,7 @@
from .nodes_kolors_mz import *
from .nodes_model_advanced import *
from .nodes_sd3 import *
from .nodes_segment_anything import *
from .nodes_testing_utils import *
from .nodes_ultimatesdupscale import *
from .nodes_upscale_model import *
256 changes: 256 additions & 0 deletions bizyair_extras/nodes_segment_anything.py
@@ -0,0 +1,256 @@
from bizyair import BizyAirBaseNode

from .nodes_segment_anything_utils import *


class BizyAir_SAMModelLoader(BizyAirBaseNode):
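    # Note: FUNCTION is left unset on the nodes in this file (except
    # BizyAirDetailMethodPredict below); BizyAirBaseNode provides the
    # execution entry point.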
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_name": (list_sam_model(),),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = ("SAM_PREDICTOR",)
NODE_DISPLAY_NAME = "☁️BizyAir Load SAM Model"


class BizyAir_GroundingDinoModelLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_name": (list_groundingdino_model(),),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = ("GROUNDING_DINO_MODEL",)
NODE_DISPLAY_NAME = "☁️BizyAir Load GroundingDino Model"


class BizyAir_VITMatteModelLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(cls):
method_list = [
"VITMatte",
"VITMatte(local)",
]
return {
"required": {
"detail_method": (method_list,),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = (
"VitMatte_MODEL",
"VitMatte_predictor",
)
NODE_DISPLAY_NAME = "☁️BizyAir Load VITMatte Model"


class BizyAir_GroundingDinoSAMSegment(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"grounding_dino_model": ("GROUNDING_DINO_MODEL", {}),
"sam_predictor": ("SAM_PREDICTOR", {}),
"image": ("IMAGE", {}),
"prompt": ("STRING", {}),
"box_threshold": (
"FLOAT",
{"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
),
"text_threshold": (
"FLOAT",
{"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = ("IMAGE", "MASK")
NODE_DISPLAY_NAME = "☁️BizyAir GroundingDinoSAMSegment"


class BizyAir_TrimapGenerate(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"detail_erode": (
"INT",
{"default": 6, "min": 1, "max": 255, "step": 1},
),
"detail_dilate": (
"INT",
{"default": 6, "min": 1, "max": 255, "step": 1},
),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("trimap",)
NODE_DISPLAY_NAME = "☁️BizyAir Trimap Generate"


class BizyAir_VITMattePredict(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE", {}),
"trimap": ("MASK",),
"vitmatte_model": ("VitMatte_MODEL", {}),
"vitmatte_predictor": ("VitMatte_predictor", {}),
"black_point": (
"FLOAT",
{
"default": 0.15,
"min": 0.01,
"max": 0.98,
"step": 0.01,
"display": "slider",
},
),
"white_point": (
"FLOAT",
{
"default": 0.99,
"min": 0.02,
"max": 0.99,
"step": 0.01,
"display": "slider",
},
),
"max_megapixels": (
"FLOAT",
{"default": 2.0, "min": 1, "max": 999, "step": 0.1},
),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
# FUNCTION = "main"
RETURN_TYPES = (
"IMAGE",
"MASK",
)
RETURN_NAMES = (
"image",
"mask",
)
NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict"


class BizyAirDetailMethodPredict(BizyAirBaseNode):
NODE_DISPLAY_NAME = "☁️BizyAir DetailMethod Predict"

@classmethod
    def INPUT_TYPES(cls):
        method_list = [
"PyMatting",
]
return {
"required": {
"image": ("IMAGE", {}),
"mask": ("MASK",),
"detail_method": (method_list,),
"detail_erode": (
"INT",
{"default": 6, "min": 1, "max": 255, "step": 1},
),
"detail_dilate": (
"INT",
{"default": 6, "min": 1, "max": 255, "step": 1},
),
"black_point": (
"FLOAT",
{
"default": 0.15,
"min": 0.01,
"max": 0.98,
"step": 0.01,
"display": "slider",
},
),
"white_point": (
"FLOAT",
{
"default": 0.99,
"min": 0.02,
"max": 0.99,
"step": 0.01,
"display": "slider",
},
),
}
}

CATEGORY = "☁️BizyAir/segment-anything"
FUNCTION = "main"
RETURN_TYPES = (
"IMAGE",
"MASK",
)
RETURN_NAMES = (
"image",
"mask",
)

def main(
self,
image,
mask,
detail_method,
detail_erode,
detail_dilate,
black_point,
white_point,
    ):
        ret_images = []
ret_masks = []
# device = comfy.model_management.get_torch_device()

for i in range(image.shape[0]):
img = torch.unsqueeze(image[i], 0)
img = pil2tensor(tensor2pil(img).convert("RGB"))
_image = tensor2pil(img).convert("RGBA")

detail_range = detail_erode + detail_dilate
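            # Erode plus dilate widths define how wide a band around the mask
            # edge is re-estimated by the matting step below.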

if detail_method == "PyMatting":
_mask = tensor2pil(
mask_edge_detail(
img, mask[i], detail_range // 8 + 1, black_point, white_point
)
)

_image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L"))

ret_images.append(pil2tensor(_image))
ret_masks.append(image2mask(_mask))
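        # Fallback: no masks were produced, so return an all-zero mask for
        # both outputs.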
if len(ret_masks) == 0:
_, height, width, _ = image.size()
empty_mask = torch.zeros(
(1, height, width), dtype=torch.uint8, device="cpu"
)
return (empty_mask, empty_mask)

return (
torch.cat(ret_images, dim=0),
torch.cat(ret_masks, dim=0),
)
134 changes: 134 additions & 0 deletions bizyair_extras/nodes_segment_anything_utils.py
@@ -0,0 +1,134 @@
import copy
from typing import List

import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter

sam_model_dir_name = "sams"
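# Only the ViT-H checkpoint is exposed for now; the smaller and HQ variants
# below are kept commented out for reference.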
sam_model_list = {
"sam_vit_h (2.56GB)": {
"model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
},
# "sam_vit_l (1.25GB)": {
# "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"
# },
# "sam_vit_b (375MB)": {
# "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# },
# "sam_hq_vit_h (2.57GB)": {
# "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
# },
# "sam_hq_vit_l (1.25GB)": {
# "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
# },
# "sam_hq_vit_b (379MB)": {
# "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth"
# },
# "mobile_sam(39MB)": {
# "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt"
# },
}

groundingdino_model_dir_name = "grounding-dino"
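# Likewise, only the SwinT variant of GroundingDINO is currently exposed.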
groundingdino_model_list = {
"GroundingDINO_SwinT_OGC (694MB)": {
"config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
"model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
},
# "GroundingDINO_SwinB (938MB)": {
# "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
# "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
# },
}


def list_sam_model():
return list(sam_model_list.keys())


def list_groundingdino_model():
return list(groundingdino_model_list.keys())


def histogram_remap(
image: torch.Tensor, blackpoint: float, whitepoint: float
) -> torch.Tensor:
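    # Linear remap: values at or below blackpoint go to 0, values at or
    # above whitepoint go to 1. The 0.001 floor keeps the scale finite when
    # the two points coincide.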
bp = min(blackpoint, whitepoint - 0.001)
scale = 1 / (whitepoint - bp)
i_dup = copy.deepcopy(image.cpu().numpy())
i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0)
return torch.from_numpy(i_dup)


def mask_edge_detail(
image: torch.Tensor,
mask: torch.Tensor,
detail_range: int = 8,
black_point: float = 0.01,
white_point: float = 0.99,
) -> torch.Tensor:
from pymatting import estimate_alpha_cf, fix_trimap

d = detail_range * 5 + 1
mask = pil2tensor(tensor2pil(mask).convert("RGB"))
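    # d is kept odd for the original cv2.GaussianBlur kernel size; the scipy
    # gaussian_filter used below only derives sigma from it.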
    if d % 2 == 0:
        d += 1
i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64))
a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64))
for index, img in enumerate(i_dup):
trimap = a_dup[index][:, :, 0] # convert to single channel
if detail_range > 0:
# trimap = cv2.GaussianBlur(trimap, (d, d), 0)
trimap = gaussian_filter(trimap, sigma=d / 2)
trimap = fix_trimap(trimap, black_point, white_point)
alpha = estimate_alpha_cf(
img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500}
)
a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1) # convert back to rgb
return torch.from_numpy(a_dup.astype(np.float32))


def pil2tensor(image: Image.Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def tensor2pil(t_image: torch.Tensor) -> Image.Image:
return Image.fromarray(
np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
)


def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
if len(tensor.shape) == 3: # Single image
return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
else: # Batch of images
return [
np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor
]


def mask2image(mask: torch.Tensor) -> Image.Image:
masks = tensor2np(mask)
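    # Note: when a batch of masks is passed, only the last one is converted
    # into the returned image.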
for m in masks:
_mask = Image.fromarray(m).convert("L")
_image = Image.new("RGBA", _mask.size, color="white")
_image = Image.composite(
_image, Image.new("RGBA", _mask.size, color="black"), _mask
)
return _image


def image2mask(image: Image.Image) -> torch.Tensor:
_image = image.convert("RGBA")
alpha = _image.split()[0]
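    # split()[0] is the red channel, not alpha: for an "L"-mode mask converted
    # to RGBA, R == G == B == the grayscale value, so it carries the mask.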
bg = Image.new("L", _image.size)
_image = Image.merge("RGBA", (bg, bg, bg, alpha))
    ret_mask = pil2tensor(_image)[:, :, :, 3]
return ret_mask


def RGB2RGBA(image: Image.Image, mask: Image.Image) -> Image.Image:
(R, G, B) = image.convert("RGB").split()
return Image.merge("RGBA", (R, G, B, mask.convert("L")))