Skip to content

Commit

Permalink
refine the code
Browse files Browse the repository at this point in the history
  • Loading branch information
wangshier108 committed Jan 6, 2025
1 parent d784f4f commit c824b04
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 88 deletions.
48 changes: 0 additions & 48 deletions bizyair_extras/nodes_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,6 @@

from .nodes_segment_anything_utils import *

# sam_model_list = {
# "sam_vit_h (2.56GB)": {
# "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# },
# # "sam_vit_l (1.25GB)": {
# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"
# # },
# # "sam_vit_b (375MB)": {
# # "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
# # },
# # "sam_hq_vit_h (2.57GB)": {
# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
# # },
# # "sam_hq_vit_l (1.25GB)": {
# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
# # },
# # "sam_hq_vit_b (379MB)": {
# # "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth"
# # },
# # "mobile_sam(39MB)": {
# # "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt"
# # },
# }

# groundingdino_model_dir_name = "grounding-dino"
# groundingdino_model_list = {
# "GroundingDINO_SwinT_OGC (694MB)": {
# "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
# "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
# },
# # "GroundingDINO_SwinB (938MB)": {
# # "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
# # "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
# # },
# }


# def list_sam_model():
# return list(sam_model_list.keys())


# def list_groundingdino_model():
# return list(groundingdino_model_list.keys())


class BizyAir_SAMModelLoader(BizyAirBaseNode):
@classmethod
Expand Down Expand Up @@ -206,7 +162,6 @@ def INPUT_TYPES(cls):

method_list = [
"PyMatting",
"GuidedFilter",
]
return {
"required": {
Expand Down Expand Up @@ -276,9 +231,6 @@ def main(
_image = tensor2pil(img).convert("RGBA")

detail_range = detail_erode + detail_dilate
if detail_method == "GuidedFilter":
_mask = guided_filter_alpha(img, mask[i], detail_range // 6 + 1)
_mask = tensor2pil(histogram_remap(_mask, black_point, white_point))

if detail_method == "PyMatting":
_mask = tensor2pil(
Expand Down
43 changes: 3 additions & 40 deletions bizyair_extras/nodes_segment_anything_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,11 @@
from typing import List
from urllib.parse import urlparse

import cv2
import groundingdino.datasets.transforms as T
import numpy as np
import torch
from PIL import Image

try:
from cv2.ximgproc import guidedFilter
except ImportError:
# print(e)
print(
f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'"
f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'."
f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m"
)

try:
from cv2.ximgproc import guidedFilter
except ImportError:
# print(e)
print(
f"Cannot import name 'guidedFilter' from 'cv2.ximgproc'"
f"\nA few nodes cannot works properly, while most nodes are not affected. Please REINSTALL package 'opencv-contrib-python'."
f"\nFor detail refer to \033[4mhttps://github.com/chflame163/ComfyUI_LayerStyle/issues/5\033[0m"
)

from scipy.ndimage import convolve, gaussian_filter

sam_model_dir_name = "sams"
sam_model_list = {
Expand Down Expand Up @@ -75,23 +54,6 @@ def list_groundingdino_model():
return list(groundingdino_model_list.keys())


def guided_filter_alpha(
image: torch.Tensor, mask: torch.Tensor, filter_radius: int
) -> torch.Tensor:
sigma = 0.15
d = filter_radius + 1
mask = pil2tensor(tensor2pil(mask).convert("RGB"))
if not bool(d % 2):
d += 1
s = sigma / 10
i_dup = copy.deepcopy(image.cpu().numpy())
a_dup = copy.deepcopy(mask.cpu().numpy())
for index, image in enumerate(i_dup):
alpha_work = a_dup[index]
i_dup[index] = guidedFilter(image, alpha_work, d, s)
return torch.from_numpy(i_dup)


def histogram_remap(
image: torch.Tensor, blackpoint: float, whitepoint: float
) -> torch.Tensor:
Expand Down Expand Up @@ -120,7 +82,8 @@ def mask_edge_detail(
for index, img in enumerate(i_dup):
trimap = a_dup[index][:, :, 0] # convert to single channel
if detail_range > 0:
trimap = cv2.GaussianBlur(trimap, (d, d), 0)
# trimap = cv2.GaussianBlur(trimap, (d, d), 0)
trimap = gaussian_filter(trimap, sigma=d / 2)
trimap = fix_trimap(trimap, black_point, white_point)
alpha = estimate_alpha_cf(
img, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter": 500}
Expand Down

0 comments on commit c824b04

Please sign in to comment.