Skip to content

Commit

Permalink
Implement EdgeSAM onnx model for edge devices
Browse files Browse the repository at this point in the history
@Article{zhou2023edgesam,
  title={EdgeSAM: Prompt-In-the-Loop Distillation for On-Device Deployment of SAM},
  author={Zhou, Chong and Li, Xiangtai and Loy, Chen Change and Dai, Bo},
  journal={arXiv preprint arXiv:2312.06660},
  year={2023}
}
  • Loading branch information
healthonrails committed Dec 16, 2023
1 parent 812bd25 commit b61c22b
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 23 deletions.
14 changes: 14 additions & 0 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from annolid.postprocessing.quality_control import pred_dict_to_labelme
from annolid.annotation.keypoints import save_labels
from annolid.annotation.timestamps import convert_frame_number_to_time
from annolid.segmentation.SAM import MODELS

__appname__ = 'Annolid'
__version__ = "1.1.3"

Expand Down Expand Up @@ -452,6 +454,7 @@ def __init__(self,

self.actions.tool = tuple(_action_tools)
self.tools.clear()

utils.addActions(self.tools, self.actions.tool)
utils.addActions(self.menus.frames, (frames,))
utils.addActions(self.menus.open_video, (open_video, open_audio))
Expand Down Expand Up @@ -482,6 +485,17 @@ def __init__(self,
# Callbacks:
self.zoomWidget.valueChanged.connect(self.paintCanvas)

self._selectAiModelComboBox.clear()
self._selectAiModelComboBox.addItems([model.name for model in MODELS])
self._selectAiModelComboBox.setCurrentIndex(1)
self._selectAiModelComboBox.currentIndexChanged.connect(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode in ["ai_polygon", "ai_mask"]
else None
)

self.populateModeActions()

def update_step_size(self, value):
Expand Down
18 changes: 9 additions & 9 deletions annolid/gui/widgets/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import imgviz
import labelme.ai
from labelme.logger import logger

from annolid.segmentation import SAM

# TODO(unknown):
# - [maybe] Find optimal epsilon value.
Expand Down Expand Up @@ -156,15 +156,15 @@ def createMode(self, value):
self.current = None

def initializeAiModel(self, name):
if name not in [model.name for model in labelme.ai.MODELS]:
if name not in [model.name for model in SAM.MODELS]:
raise ValueError("Unsupported ai model: %s" % name)
model = [model for model in labelme.ai.MODELS if model.name == name][0]
model = [model for model in SAM.MODELS if model.name == name][0]

if self._ai_model is not None and self._ai_model.name == model.name:
logger.debug("AI model is already initialized: %r" % model.name)
else:
logger.debug("Initializing AI model: %r" % model.name)
self._ai_model = labelme.ai.SegmentAnythingModel(
self._ai_model = SAM.SegmentAnythingModel(
name=model.name,
encoder_path=gdown.cached_download(
url=model.encoder_weight.url,
Expand Down Expand Up @@ -568,7 +568,7 @@ def mousePressEvent(self, ev):
self.current.shape_type = "circle"
self.line.points = [pos, pos]
if (
self.createMode in ["ai_polygon","ai_mask"]
self.createMode in ["ai_polygon", "ai_mask"]
and is_shift_pressed
):
self.line.point_labels = [0, 0]
Expand Down Expand Up @@ -845,10 +845,10 @@ def paintEvent(self, event):

# draw crosshair
if (self._crosshair[self._createMode]
and self.drawing()
and self.prevMovePoint
and not self.outOfPixmap(self.prevMovePoint)
):
and self.drawing()
and self.prevMovePoint
and not self.outOfPixmap(self.prevMovePoint)
):
p.setPen(QtGui.QColor(0, 0, 0))
p.drawLine(
0,
Expand Down
11 changes: 11 additions & 0 deletions annolid/segmentation/SAM/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
Weight = collections.namedtuple("Weight", ["url", "md5"])

MODELS = [
Model(
name="Segment-Anything (Edge)",
encoder_weight=Weight(
url="https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_encoder.onnx", # NOQA
md5="e0745d06f3ee9c5e01a667b56a40875b",
),
decoder_weight=Weight(
url="https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_decoder.onnx", # NOQA
md5="9fe1d5521b4349ab710e9cc970936970",
),
),
Model(
name="Segment-Anything (speed)",
encoder_weight=Weight(
Expand Down
88 changes: 74 additions & 14 deletions annolid/segmentation/SAM/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import onnxruntime
import PIL.Image
import skimage.measure

import cv2
from labelme.logger import logger


Expand Down Expand Up @@ -72,6 +72,18 @@ def predict_polygon_from_points(self, points, point_labels):
)
return polygon

def predict_mask_from_points(self, points, point_labels):
    """Predict a segmentation mask for the current image from prompt points.

    Args:
        points: Prompt point coordinates, forwarded unchanged to
            ``_compute_mask_from_points``.
        point_labels: Label for each prompt point, forwarded unchanged.

    Returns:
        The mask produced by ``_compute_mask_from_points``.
    """
    # Fetch the embedding first so any lazy computation happens before
    # the remaining instance attributes are read.
    embedding = self._get_image_embedding()
    return _compute_mask_from_points(
        image_size=self._image_size,
        decoder_session=self._decoder_session,
        image=self._image,
        image_embedding=embedding,
        points=points,
        point_labels=point_labels,
    )


def _compute_scale_to_resize_image(image_size, image):
height, width = image.shape[:2]
Expand Down Expand Up @@ -99,6 +111,17 @@ def _resize_image(image_size, image):
return scale, scaled_image


def postprocess_masks(mask, img_size, input_size, original_size):
    """Upsample decoder masks to the resolution of the original image.

    The masks are resized to a square of side ``img_size``, cropped to the
    top-left ``input_size`` region (presumably to discard the padded area
    of the model input -- confirm against the caller), and then resized
    to ``original_size``.

    Args:
        mask: Array of shape ``(1, C, h, w)``.
        img_size: Side length of the square intermediate resolution.
        input_size: ``(height, width)`` of the region to keep after the
            first resize.
        original_size: ``(height, width)`` of the final output.

    Returns:
        Array of shape ``(1, C, original_height, original_width)``.
    """

    def _resize_keep_channels(array, width, height):
        # Cast to plain ints: callers may pass NumPy integer scalars,
        # which some OpenCV builds refuse to parse as ``dsize``.
        resized = cv2.resize(
            array,
            (int(width), int(height)),
            interpolation=cv2.INTER_LINEAR,
        )
        # cv2.resize drops a singleton channel axis; restore it so the
        # array stays (H, W, C) for the indexing and transpose below.
        if resized.ndim == 2:
            resized = resized[:, :, None]
        return resized

    mask = mask.squeeze(0).transpose(1, 2, 0)  # (1, C, h, w) -> (h, w, C)
    mask = _resize_keep_channels(mask, img_size, img_size)
    mask = mask[:int(input_size[0]), :int(input_size[1]), :]
    mask = _resize_keep_channels(mask, original_size[1], original_size[0])
    return mask.transpose(2, 0, 1)[None, :, :, :]  # back to (1, C, H, W)


def _compute_image_embedding(image_size, encoder_session, image):
image = imgviz.asrgb(image)

Expand All @@ -115,8 +138,12 @@ def _compute_image_embedding(image_size, encoder_session, image):
),
)
x = x.transpose(2, 0, 1)[None, :, :, :]

output = encoder_session.run(output_names=None, input_feed={"x": x})
input_names = [input.name for input in encoder_session.get_inputs()]
if input_names[0] == 'image':
output = encoder_session.run(
output_names=None, input_feed={"image": x})
else:
output = encoder_session.run(output_names=None, input_feed={"x": x})
image_embedding = output[0]

return image_embedding
Expand All @@ -128,7 +155,7 @@ def _get_contour_length(contour):
return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def _compute_polygon_from_points(
def _compute_mask_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
input_point = np.array(points, dtype=np.float32)
Expand All @@ -152,28 +179,61 @@ def _compute_polygon_from_points(
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.array([-1], dtype=np.float32)

decoder_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}
input_names = [input.name for input in decoder_session.get_inputs()]
if len(input_names) <= 3:
outputs = decoder_session.run(None, {
'image_embeddings': image_embedding,
'point_coords': onnx_coord,
'point_labels': onnx_label,
})
scores, masks = outputs
masks = postprocess_masks(
masks, image_size, (new_height, new_width), np.array(image.shape[:2]))

masks, _, _ = decoder_session.run(None, decoder_inputs)
else:
decoder_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}

masks, _, _ = decoder_session.run(None, decoder_inputs)
mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
mask = mask > 0.0

MIN_SIZE_RATIO = 0.05
skimage.morphology.remove_small_objects(
mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
)

if 0:
imgviz.io.imsave(
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
)
return mask


def _compute_polygon_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
mask = _compute_mask_from_points(
image_size=image_size,
decoder_session=decoder_session,
image=image,
image_embedding=image_embedding,
points=points,
point_labels=point_labels,
)

contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
contour = max(contours, key=_get_contour_length)
POLYGON_APPROX_TOLERANCE = 0.004
polygon = skimage.measure.approximate_polygon(
coords=contour,
tolerance=np.ptp(contour, axis=0).max() / 100,
tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
)
polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
polygon = polygon[:-1] # drop last point that is duplicate of first point
Expand Down

0 comments on commit b61c22b

Please sign in to comment.