
separate SegmentAnythingUltra V2 into nodes (#291)
* separate SegmentAnythingUltra V2 into nodes

* refine the code

* refine the code

* refine the code

* refine the code

* refine the code

* add code file

* refine the code

* refine the code

* refine the code

* test

* refine the code

* sam_node_mount_route

---------

Co-authored-by: FengWen <[email protected]>
Co-authored-by: FengWen <[email protected]>
3 people authored Jan 8, 2025
1 parent 54199ea commit ced2dc4
Showing 6 changed files with 1,316 additions and 1 deletion.
1 change: 1 addition & 0 deletions bizyair_extras/__init__.py
@@ -13,6 +13,7 @@
from .nodes_kolors_mz import *
from .nodes_model_advanced import *
from .nodes_sd3 import *
from .nodes_segment_anything import *
from .nodes_testing_utils import *
from .nodes_ultimatesdupscale import *
from .nodes_upscale_model import *
256 changes: 256 additions & 0 deletions bizyair_extras/nodes_segment_anything.py
@@ -0,0 +1,256 @@
import torch

from bizyair import BizyAirBaseNode

from .nodes_segment_anything_utils import *


class BizyAir_SAMModelLoader(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (list_sam_model(),),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("SAM_PREDICTOR",)
    NODE_DISPLAY_NAME = "☁️BizyAir Load SAM Model"


class BizyAir_GroundingDinoModelLoader(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (list_groundingdino_model(),),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("GROUNDING_DINO_MODEL",)
    NODE_DISPLAY_NAME = "☁️BizyAir Load GroundingDino Model"


class BizyAir_VITMatteModelLoader(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        method_list = [
            "VITMatte",
            "VITMatte(local)",
        ]
        return {
            "required": {
                "detail_method": (method_list,),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = (
        "VitMatte_MODEL",
        "VitMatte_predictor",
    )
    NODE_DISPLAY_NAME = "☁️BizyAir Load VITMatte Model"


class BizyAir_GroundingDinoSAMSegment(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}),
                "sam_predictor": ("SAM_PREDICTOR", {}),
                "image": ("IMAGE", {}),
                "prompt": ("STRING", {}),
                "box_threshold": (
                    "FLOAT",
                    {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
                ),
                "text_threshold": (
                    "FLOAT",
                    {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01},
                ),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("IMAGE", "MASK")
    NODE_DISPLAY_NAME = "☁️BizyAir GroundingDinoSAMSegment"


class BizyAir_TrimapGenerate(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "detail_erode": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
                "detail_dilate": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("trimap",)
    NODE_DISPLAY_NAME = "☁️BizyAir Trimap Generate"


class BizyAir_VITMattePredict(BizyAirBaseNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", {}),
                "trimap": ("MASK",),
                "vitmatte_model": ("VitMatte_MODEL", {}),
                "vitmatte_predictor": ("VitMatte_predictor", {}),
                "black_point": (
                    "FLOAT",
                    {
                        "default": 0.15,
                        "min": 0.01,
                        "max": 0.98,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
                "white_point": (
                    "FLOAT",
                    {
                        "default": 0.99,
                        "min": 0.02,
                        "max": 0.99,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
                "max_megapixels": (
                    "FLOAT",
                    {"default": 2.0, "min": 1, "max": 999, "step": 0.1},
                ),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
    )
    RETURN_NAMES = (
        "image",
        "mask",
    )
    NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict"


class BizyAirDetailMethodPredict(BizyAirBaseNode):
    NODE_DISPLAY_NAME = "☁️BizyAir DetailMethod Predict"

    @classmethod
    def INPUT_TYPES(cls):
        method_list = [
            "PyMatting",
        ]
        return {
            "required": {
                "image": ("IMAGE", {}),
                "mask": ("MASK",),
                "detail_method": (method_list,),
                "detail_erode": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
                "detail_dilate": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
                "black_point": (
                    "FLOAT",
                    {
                        "default": 0.15,
                        "min": 0.01,
                        "max": 0.98,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
                "white_point": (
                    "FLOAT",
                    {
                        "default": 0.99,
                        "min": 0.02,
                        "max": 0.99,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    FUNCTION = "main"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
    )
    RETURN_NAMES = (
        "image",
        "mask",
    )

    def main(
        self,
        image,
        mask,
        detail_method,
        detail_erode,
        detail_dilate,
        black_point,
        white_point,
    ):
        ret_images = []
        ret_masks = []
        # device = comfy.model_management.get_torch_device()

        for i in range(image.shape[0]):
            # Force each batch item to a 3-channel RGB tensor before matting.
            img = torch.unsqueeze(image[i], 0)
            img = pil2tensor(tensor2pil(img).convert("RGB"))
            _image = tensor2pil(img).convert("RGBA")

            detail_range = detail_erode + detail_dilate

            # "PyMatting" is the only detail method currently exposed.
            if detail_method == "PyMatting":
                _mask = tensor2pil(
                    mask_edge_detail(
                        img, mask[i], detail_range // 8 + 1, black_point, white_point
                    )
                )

            _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L"))

            ret_images.append(pil2tensor(_image))
            ret_masks.append(image2mask(_mask))

        if len(ret_masks) == 0:
            # Nothing was produced; return an all-black mask in both slots.
            _, height, width, _ = image.size()
            empty_mask = torch.zeros(
                (1, height, width), dtype=torch.uint8, device="cpu"
            )
            return (empty_mask, empty_mask)

        return (
            torch.cat(ret_images, dim=0),
            torch.cat(ret_masks, dim=0),
        )
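
BizyAirDetailMethodPredict is the only node in this file with a local main() implementation; the others leave execution to BizyAirBaseNode. Its matting path can therefore be exercised directly with dummy tensors — a minimal sketch, assuming the bizyair_extras package is importable and pymatting is installed; no ComfyUI graph wiring is required:

import torch

from bizyair_extras.nodes_segment_anything import BizyAirDetailMethodPredict

# One gray 64x64 image in ComfyUI's (batch, height, width, channel) layout,
# plus a coarse square mask in [0, 1] whose edges are to be refined.
image = torch.full((1, 64, 64, 3), 0.5)
mask = torch.zeros((1, 64, 64))
mask[:, 16:48, 16:48] = 1.0

node = BizyAirDetailMethodPredict()
rgba, refined = node.main(
    image,
    mask,
    detail_method="PyMatting",
    detail_erode=6,
    detail_dilate=6,
    black_point=0.15,
    white_point=0.99,
)
print(rgba.shape, refined.shape)  # (1, 64, 64, 4) and (1, 64, 64)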
134 changes: 134 additions & 0 deletions bizyair_extras/nodes_segment_anything_utils.py
@@ -0,0 +1,134 @@
import copy
from typing import List, Union

import numpy as np
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter

sam_model_dir_name = "sams"
sam_model_list = {
    "sam_vit_h (2.56GB)": {
        "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    },
    # "sam_vit_l (1.25GB)": {
    #     "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"
    # },
    # "sam_vit_b (375MB)": {
    #     "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    # },
    # "sam_hq_vit_h (2.57GB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
    # },
    # "sam_hq_vit_l (1.25GB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
    # },
    # "sam_hq_vit_b (379MB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth"
    # },
    # "mobile_sam(39MB)": {
    #     "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt"
    # },
}

groundingdino_model_dir_name = "grounding-dino"
groundingdino_model_list = {
    "GroundingDINO_SwinT_OGC (694MB)": {
        "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
        "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
    },
    # "GroundingDINO_SwinB (938MB)": {
    #     "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
    #     "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    # },
}


def list_sam_model():
    return list(sam_model_list.keys())


def list_groundingdino_model():
    return list(groundingdino_model_list.keys())


def histogram_remap(
    image: torch.Tensor, blackpoint: float, whitepoint: float
) -> torch.Tensor:
    bp = min(blackpoint, whitepoint - 0.001)
    scale = 1 / (whitepoint - bp)
    i_dup = copy.deepcopy(image.cpu().numpy())
    i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0)
    return torch.from_numpy(i_dup)
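

# Illustrative usage (a sketch, not part of the original module): stretch a
# soft mask so values at or below 0.15 become black and values at or above
# 0.99 become white:
#   remapped = histogram_remap(mask_tensor, blackpoint=0.15, whitepoint=0.99)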


def mask_edge_detail(
    image: torch.Tensor,
    mask: torch.Tensor,
    detail_range: int = 8,
    black_point: float = 0.01,
    white_point: float = 0.99,
) -> torch.Tensor:
    from pymatting import estimate_alpha_cf, fix_trimap

    d = detail_range * 5 + 1
    mask = pil2tensor(tensor2pil(mask).convert("RGB"))
    if not bool(d % 2):
        d += 1  # the blur width must be odd
    i_dup = copy.deepcopy(image.cpu().numpy().astype(np.float64))
    a_dup = copy.deepcopy(mask.cpu().numpy().astype(np.float64))
    for index, img in enumerate(i_dup):
        trimap = a_dup[index][:, :, 0]  # convert to single channel
        if detail_range > 0:
            # trimap = cv2.GaussianBlur(trimap, (d, d), 0)
            trimap = gaussian_filter(trimap, sigma=d / 2)
        trimap = fix_trimap(trimap, black_point, white_point)
        alpha = estimate_alpha_cf(
            img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500}
        )
        a_dup[index] = np.stack([alpha, alpha, alpha], axis=-1)  # convert back to rgb
    return torch.from_numpy(a_dup.astype(np.float32))
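

# Illustrative usage (a sketch, not part of the original module): given `img`
# shaped (1, H, W, 3) with values in [0, 1] and `coarse` a (1, H, W) mask,
#   refined = mask_edge_detail(img, coarse, detail_range=2,
#                              black_point=0.15, white_point=0.99)
# returns a (1, H, W, 3) float32 tensor whose three channels all hold the
# refined alpha.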


def pil2tensor(image: Image.Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)


def tensor2pil(t_image: torch.Tensor) -> Image.Image:
    return Image.fromarray(
        np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
    )


def tensor2np(tensor: torch.Tensor) -> Union[np.ndarray, List[np.ndarray]]:
    if len(tensor.shape) == 3:  # single image
        return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
    else:  # batch of images
        return [
            np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor
        ]


def mask2image(mask: torch.Tensor) -> Image.Image:
    # Note: for a batched mask, only the last item ends up in the returned image.
    masks = tensor2np(mask)
    for m in masks:
        _mask = Image.fromarray(m).convert("L")
        _image = Image.new("RGBA", _mask.size, color="white")
        _image = Image.composite(
            _image, Image.new("RGBA", _mask.size, color="black"), _mask
        )
    return _image


def image2mask(image: Image.Image) -> torch.Tensor:
    _image = image.convert("RGBA")
    alpha = _image.split()[0]  # the first (red) channel serves as the alpha source
    bg = Image.new("L", _image.size)
    _image = Image.merge("RGBA", (bg, bg, bg, alpha))
    return pil2tensor(_image)[:, :, :, 3]  # (1, H, W) float mask


def RGB2RGBA(image: Image.Image, mask: Image.Image) -> Image.Image:
    (R, G, B) = image.convert("RGB").split()
    return Image.merge("RGBA", (R, G, B, mask.convert("L")))
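
These helpers follow ComfyUI's (batch, height, width, channel) float-tensor convention. A minimal round-trip sketch — assuming only Pillow and torch are installed and the package is importable; the shapes in the comments are what the functions above produce:

from PIL import Image

from bizyair_extras.nodes_segment_anything_utils import (
    RGB2RGBA,
    image2mask,
    pil2tensor,
    tensor2pil,
)

pil = Image.new("RGB", (32, 32), color="white")
t = pil2tensor(pil)  # (1, 32, 32, 3) float32 in [0, 1]
back = tensor2pil(t)  # PIL image again (batch dim squeezed away)
rgba = RGB2RGBA(back, back.convert("L"))  # luminance becomes the alpha channel
m = image2mask(rgba)  # (1, 32, 32) float mask taken from the red channel
print(t.shape, m.shape)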