[Feature] onnx runtime for label anything #100

Open · wants to merge 18 commits into base: main
30 changes: 30 additions & 0 deletions label_anything/readme.md
@@ -378,3 +378,33 @@ When finished, we can get the model test visualization. On the left is the annot
With Label-Studio's semi-automated annotation, users can segment and detect objects with a single mouse click, greatly improving annotation efficiency.

Some of the code was borrowed from Pull Request ID 253 of label-studio-ml-backend; thank you to the author for their contribution. Thanks also to community member [ATang0729](https://github.com/ATang0729) for re-labeling the meow dataset for script testing, and to [JimmyMa99](https://github.com/JimmyMa99) for the conversion script, config template, and documentation optimization.

## (Beta) 🚀 SAM backend inference with ONNX Runtime 🚀 (optional)

We use ONNX Runtime for SAM backend inference to speed it up: on an RTX 3090, a single inference takes 4.6 s with PyTorch and 0.24 s with ONNX Runtime.

First, download the pre-converted ONNX models from Hugging Face.

```shell
cd path/to/playground/label_anything
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
```
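
Optionally, you can check which execution providers ONNX Runtime will use before starting the backend. This is just a quick sanity check, assuming `onnxruntime` (or `onnxruntime-gpu`) is already installed:

```python
# List the available ONNX Runtime execution providers.
# CUDAExecutionProvider should appear if GPU inference is possible;
# otherwise inference falls back to CPUExecutionProvider.
import onnxruntime

print(onnxruntime.get_available_providers())
```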

Then start the backend with ONNX inference enabled.

```shell
cd path/to/playground/label_anything

label-studio-ml start sam --port 8003 --with \
sam_config=vit_b \
sam_checkpoint_file=./sam_vit_b_01ec64.pth \
out_mask=True \
out_bbox=True \
device=cuda:0 \
onnx=True \
onnx_encoder_file=encoder.onnx \
onnx_decoder_file=decoder.onnx
# device=cuda:0 runs inference on the GPU; to run on CPU, replace cuda:0 with cpu
# out_poly=True returns the enclosing polygon annotation
```

⚠ Currently only sam_vit_b is supported.
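
For reference, below is a rough sketch of the inference flow the ONNX backend performs for a single point prompt, pieced together from this PR's `pre_process`/`run_encoder`/`run_decoder` code. It is not the backend API itself; the image path and click coordinates are placeholders, and the input names (`x` for the encoder; `image_embeddings`, `point_coords`, etc. for the decoder) assume the specific exports downloaded above.

```python
import cv2
import numpy as np
import onnxruntime
from segment_anything.utils.transforms import ResizeLongestSide

providers = onnxruntime.get_available_providers()
encoder = onnxruntime.InferenceSession("encoder.onnx", providers=providers)
decoder = onnxruntime.InferenceSession("decoder.onnx", providers=providers)

image = cv2.cvtColor(cv2.imread("cat.jpg"), cv2.COLOR_BGR2RGB)  # "cat.jpg" is a placeholder path
transform = ResizeLongestSide(1024)

# Encoder input: longest side resized to 1024, normalized with SAM's mean/std,
# padded to 1024x1024, NCHW float32.
x = transform.apply_image(image).astype(np.float32)
x = (x - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])
x = np.pad(x, ((0, 1024 - x.shape[0]), (0, 1024 - x.shape[1]), (0, 0)))
x = x.transpose(2, 0, 1)[None].astype(np.float32)
embedding = encoder.run(None, {"x": x})[0]

# Decoder input: one foreground click (label 1) plus a padding point (label -1),
# with coordinates mapped into the resized image space.
point = np.array([[500.0, 375.0]])  # placeholder click, in original image coordinates
coords = np.concatenate([point, np.zeros((1, 2))], axis=0)[None, :, :]
coords = transform.apply_coords(coords, image.shape[:2]).astype(np.float32)
labels = np.array([[1, -1]], dtype=np.float32)

masks, _, _ = decoder.run(None, {
    "image_embeddings": embedding,
    "point_coords": coords,
    "point_labels": labels,
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
})
mask = (masks[0, 0] > 0.0).astype(np.uint8)  # binary mask at the original resolution
```

In SAM's prompt encoding, label 1 marks a foreground point, labels 2 and 3 mark the two corners of a box prompt, and label -1 marks the padding point; this is exactly what `run_decoder` in this PR constructs.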
30 changes: 30 additions & 0 deletions label_anything/readme_zh.md
@@ -384,5 +384,35 @@ python tools/test.py data/my_set/mask-rcnn_r50_fpn.py path/of/your/checkpoint --

This completes the semi-automated annotation workflow. With Label-Studio's semi-automated annotation, users can segment and detect objects with a single mouse click, greatly improving annotation efficiency. Some of the code was borrowed from Pull Request ID 253 of label-studio-ml-backend; thank you to the author for the contribution. Thanks also to community member [ATang0729](https://github.com/ATang0729) for re-labeling the meow dataset for script testing, and to [JimmyMa99](https://github.com/JimmyMa99) for the conversion script, config template, and documentation optimization.

## (Beta) 🚀 SAM backend inference with ONNX Runtime 🚀 (optional)

We use ONNX Runtime for SAM backend inference to speed it up: on an RTX 3090, a single inference takes 4.6 s with PyTorch and only 0.24 s with ONNX Runtime.

First, download the pre-converted ONNX models from Hugging Face.

```shell
cd path/to/playground/label_anything
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
# Other versions can be downloaded from https://github.com/vietanhdev/anylabeling-assets/releases/tag/v0.2.0
```

Then start the backend with ONNX inference enabled.

```shell
cd path/to/playground/label_anything

label-studio-ml start sam --port 8003 --with \
out_mask=True \
out_bbox=True \
device=cuda:0 \
onnx=True \
onnx_encoder_file='encoder.onnx' \
onnx_decoder_file='decoder.onnx'
# device=cuda:0 runs inference on the GPU; to run on CPU, replace cuda:0 with cpu
# out_poly=True returns the enclosing polygon annotation
```

Collaborator: So without `onnx=True`, it falls back to PyTorch inference, right?

Contributor (author): Yes.




228 changes: 194 additions & 34 deletions label_anything/sam/mmdetection.py
@@ -7,6 +7,7 @@
import numpy as np
from label_studio_converter import brush
import torch
from torch.nn import functional as F

import cv2

@@ -19,10 +20,45 @@

# from mmdet.apis import inference_detector, init_detector
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.utils.transforms import ResizeLongestSide
import random
import string
logger = logging.getLogger(__name__)


import onnxruntime
import time

def load_my_onnx(encoder_model_abs_path, decoder_model_abs_path):
    # !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/encoder.onnx
    # !wget https://huggingface.co/visheratin/segment-anything-vit-b/resolve/main/decoder.onnx
    # if onnx_config == 'vit_b':
    #     encoder_model_abs_path = "models/segment_anything_vit_b_encoder_quant.onnx"
    #     decoder_model_abs_path = "models/segment_anything_vit_b_decoder_quant.onnx"
    # elif onnx_config == 'vit_l':
    #     encoder_model_abs_path = "models/segment_anything_vit_l_encoder_quant.onnx"
    #     decoder_model_abs_path = "models/segment_anything_vit_l_decoder_quant.onnx"
    # elif onnx_config == 'vit_h':
    #     encoder_model_abs_path = "models/segment_anything_vit_h_encoder_quant.onnx"
    #     decoder_model_abs_path = "models/segment_anything_vit_h_decoder_quant.onnx"

    providers = onnxruntime.get_available_providers()
    if providers:
        logging.info(
            "Available providers for ONNXRuntime: %s", ", ".join(providers)
        )
    else:
        logging.warning("No available providers for ONNXRuntime")
    encoder_session = onnxruntime.InferenceSession(
        encoder_model_abs_path, providers=providers
    )
    decoder_session = onnxruntime.InferenceSession(
        decoder_model_abs_path, providers=providers
    )

    return encoder_session, decoder_session
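# Usage sketch: predict() unpacks the pair returned here when onnx=True, e.g.
#   encoder_session, decoder_session = load_my_onnx("encoder.onnx", "decoder.onnx")
# ONNX Runtime tries providers in the order given, so CUDAExecutionProvider (when
# available) takes precedence over CPUExecutionProvider.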


def load_my_model(device="cuda:0", sam_config="vit_b", sam_checkpoint_file="sam_vit_b_01ec64.pth"):
    """
    Loads the Segment Anything model when Label Studio initializes, so that calling it outside MyModel does not reload the model on every prediction
@@ -50,11 +86,18 @@ def __init__(self,
                 out_poly=False,
                 score_threshold=0.5,
                 device='cpu',
                 onnx=False,
                 onnx_encoder_file=None,
                 onnx_decoder_file=None,
                 **kwargs):

        super(MMDetection, self).__init__(**kwargs)

        self.onnx = onnx
        if self.onnx:
            PREDICTOR = load_my_onnx(onnx_encoder_file, onnx_decoder_file)
        else:
            PREDICTOR = load_my_model(device, sam_config, sam_checkpoint_file)
        self.PREDICTOR = PREDICTOR

        self.out_mask = out_mask
@@ -132,6 +175,79 @@ def __init__(self,
        # self.model = init_detector(config_file, checkpoint_file, device=device)
        self.score_thresh = score_threshold

    ####################################################################################################

    def pre_process(self, image):
        # Resize the longest side to 1024, normalize with SAM's pixel mean/std,
        # and pad to a fixed 1024x1024 input for the ONNX encoder.
        image_size = 1024
        transform = ResizeLongestSide(image_size)

        input_image = transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device="cpu")
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
        pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
        pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
        x = (input_image_torch - pixel_mean) / pixel_std
        h, w = x.shape[-2:]
        padh = image_size - h
        padw = image_size - w
        x = F.pad(x, (0, padw, 0, padh))
        x = x.numpy()

        encoder_inputs = {
            "x": x,
        }
        return encoder_inputs, image.shape[:2]

    def run_encoder(self, encoder_inputs):
        output = self.encoder_session.run(None, encoder_inputs)
        image_embedding = output[0]
        return image_embedding

    def run_decoder(self, image_embedding, input_prompt, img_size):
        (original_height, original_width) = img_size
        points = input_prompt['points']
        masks = input_prompt['mask']
        boxes = input_prompt['boxes']
        labels = input_prompt['label']

        image_size = 1024
        transform = ResizeLongestSide(image_size)
        if boxes is not None:
            # Box prompt: the two corners use labels 2 and 3, plus a padding point with label -1.
            onnx_box_coords = boxes.reshape(2, 2)
            input_labels = np.array([2, 3])

            onnx_coord = np.concatenate([onnx_box_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
            onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
        elif points is not None:
            # Point prompt: a single foreground point (label 1) plus a padding point with label -1.
            input_point = points
            input_label = np.array([1])
            onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
            onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

        onnx_coord = transform.apply_coords(onnx_coord, img_size).astype(np.float32)

        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(
                img_size, dtype=np.float32
            ),
        }
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)
        # masks = masks[0, 0, :, :] # Only get 1 mask
        masks = masks > 0.0
        # masks = masks.reshape(img_size)
        return masks
    ##########################################################################################

    def _get_image_url(self, task):
        image_url = task['data'].get(
            self.value) or task['data'].get(DATA_UNDEFINED_NAME)
@@ -155,9 +271,8 @@ def _get_image_url(self, task):
        return image_url

    def predict(self, tasks, **kwargs):
        # shared section for both the ONNX and the PyTorch paths
        start = time.time()
        results = []
        assert len(tasks) == 1
        task = tasks[0]
@@ -170,54 +285,99 @@
        # image = cv2.imread(f"./{split}")
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        prompt_type = kwargs['context']['result'][0]['type']
        original_height = kwargs['context']['result'][0]['original_height']
        original_width = kwargs['context']['result'][0]['original_width']
        #############################################
        if self.onnx:
            self.encoder_session, self.decoder_session = self.PREDICTOR
            encoder_inputs, _ = self.pre_process(image)

            input_prompt = {}
            input_prompt['boxes'] = input_prompt['mask'] = input_prompt['points'] = input_prompt['label'] = None

            if prompt_type == 'keypointlabels':
                # getting x and y coordinates of the keypoint
                x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
                y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
                output_label = kwargs['context']['result'][0]['value']['labels'][0]

                input_prompt['points'] = np.array([[x, y]])
                input_prompt['label'] = np.array([1])

            if prompt_type == 'rectanglelabels':
                x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
                y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
                w = kwargs['context']['result'][0]['value']['width'] * original_width / 100
                h = kwargs['context']['result'][0]['value']['height'] * original_height / 100

                output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0]

                input_prompt['boxes'] = np.array([x, y, x + w, y + h])
                input_prompt['label'] = np.array([2, 3])

            # encoder, then decoder
            image_embedding = self.run_encoder(encoder_inputs)
            masks = self.run_decoder(image_embedding, input_prompt,
                                     (original_height, original_width))
            masks = masks[0].astype(np.uint8)
            # mask = masks.astype(np.uint8)
            # shapes = self.post_process(masks, resized_ratio)

        else:
            predictor = self.PREDICTOR

            if prompt_type == 'keypointlabels':
                # getting x and y coordinates of the keypoint
                x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
                y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
                output_label = kwargs['context']['result'][0]['value']['labels'][0]

                predictor.set_image(image)

                masks, scores, logits = predictor.predict(
                    point_coords=np.array([[x, y]]),
                    # box=np.array([x.cpu() for x in bbox[:4]]),
                    point_labels=np.array([1]),
                    multimask_output=False,
                )

            if prompt_type == 'rectanglelabels':
                x = kwargs['context']['result'][0]['value']['x'] * original_width / 100
                y = kwargs['context']['result'][0]['value']['y'] * original_height / 100
                w = kwargs['context']['result'][0]['value']['width'] * original_width / 100
                h = kwargs['context']['result'][0]['value']['height'] * original_height / 100

                output_label = kwargs['context']['result'][0]['value']['rectanglelabels'][0]

                predictor.set_image(image)

                masks, scores, logits = predictor.predict(
                    # point_coords=np.array([[x, y]]),
                    box=np.array([x, y, x + w, y + h]),
                    point_labels=np.array([1]),
                    multimask_output=False,
                )

        mask = masks[0].astype(np.uint8)  # each mask has shape [H, W]
        # converting the mask from the model to RLE format which is usable in Label Studio
        # find the contours
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        end = time.time()
        print(end - start)
        ########################

        # compute the bounding rectangle
