From b61c22bf6c964ecab69add8c9b5bdc22c05b1691 Mon Sep 17 00:00:00 2001
From: healthonrails
Date: Fri, 15 Dec 2023 19:45:29 -0500
Subject: [PATCH] Implement EdgeSAM ONNX model for edge devices

Register the EdgeSAM encoder/decoder ONNX weights as a new
"Segment-Anything (Edge)" entry, route the canvas AI-model loading
through annolid.segmentation.SAM instead of labelme.ai, and dispatch on
the ONNX input signatures so the original SAM exports and the lighter
EdgeSAM export share one code path.

@article{zhou2023edgesam,
  title={EdgeSAM: Prompt-In-the-Loop Distillation for On-Device Deployment of SAM},
  author={Zhou, Chong and Li, Xiangtai and Loy, Chen Change and Dai, Bo},
  journal={arXiv preprint arXiv:2312.06660},
  year={2023}
}
---
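
For reviewers, a minimal standalone sketch of the inference path this
patch wires up (this notes section is not part of the commit). The
weight filenames, the test frame, the click location, the execution
provider, and the pixel-normalization constants are placeholders and
assumptions, not taken from the diff; the input-name dispatch, the
three-input decoder signature, and the mask postprocessing mirror the
code below.

# Hypothetical sketch, assuming the two EdgeSAM weight files from the
# URLs in this patch have already been downloaded locally.
import cv2
import numpy as np
import onnxruntime

IMAGE_SIZE = 1024  # square input resolution used by the SAM family

encoder = onnxruntime.InferenceSession(
    "edge_sam_3x_encoder.onnx", providers=["CPUExecutionProvider"])
decoder = onnxruntime.InferenceSession(
    "edge_sam_3x_decoder.onnx", providers=["CPUExecutionProvider"])

image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
orig_h, orig_w = image.shape[:2]

# Resize the long side to IMAGE_SIZE and zero-pad to a square
# (condensed from _compute_scale_to_resize_image / _resize_image).
scale = IMAGE_SIZE / max(orig_h, orig_w)
new_h, new_w = int(round(orig_h * scale)), int(round(orig_w * scale))
resized = cv2.resize(image, (new_w, new_h)).astype(np.float32)
# Assumed standard SAM pixel normalization; the elided helper code in
# _compute_image_embedding handles this in the real path.
resized = (resized - [123.675, 116.28, 103.53]) / [58.395, 57.12, 57.375]
x = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32)
x[:new_h, :new_w] = resized
x = x.transpose(2, 0, 1)[None]  # (1, 3, 1024, 1024)

# EdgeSAM's encoder names its input "image"; the original SAM export
# uses "x" -- the same dispatch the patch adds to _compute_image_embedding.
embedding = encoder.run(None, {encoder.get_inputs()[0].name: x})[0]

# One positive click, mapped into the resized frame and padded with the
# (0, 0) sentinel point labeled -1, as in _compute_mask_from_points.
coords = np.array(
    [[[orig_w / 2 * scale, orig_h / 2 * scale], [0.0, 0.0]]], np.float32)
labels = np.array([[1, -1]], np.float32)

# The EdgeSAM decoder exposes only these three inputs -- no mask_input,
# has_mask_input, or orig_im_size -- hence the len(inputs) <= 3 check.
scores, masks = decoder.run(None, {
    "image_embeddings": embedding,
    "point_coords": coords,
    "point_labels": labels,
})

# postprocess_masks: upsample to the padded square, crop the padding,
# resize back to the original frame, then threshold.
m = masks.squeeze(0).transpose(1, 2, 0)  # (256, 256, C)
m = cv2.resize(m, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)
m = m[:new_h, :new_w]
m = cv2.resize(m, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
# scores holds one confidence per candidate mask; the patch takes mask 0.
mask = m.transpose(2, 0, 1)[0] > 0.0
print(mask.shape, mask.sum())

In the GUI the same path is reached by choosing "Segment-Anything
(Edge)" from the AI model combo box; because both branches dispatch on
the ONNX signature, the existing Segment-Anything exports keep working
unchanged.
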
 annolid/gui/app.py                           | 14 ++++
 annolid/gui/widgets/canvas.py                | 18 ++--
 annolid/segmentation/SAM/__init__.py         | 11 +++
 annolid/segmentation/SAM/segment_anything.py | 88 ++++++++++++++++----
 4 files changed, 108 insertions(+), 23 deletions(-)

diff --git a/annolid/gui/app.py b/annolid/gui/app.py
index 365c05c7..2ccdf28c 100644
--- a/annolid/gui/app.py
+++ b/annolid/gui/app.py
@@ -55,6 +55,8 @@
 from annolid.postprocessing.quality_control import pred_dict_to_labelme
 from annolid.annotation.keypoints import save_labels
 from annolid.annotation.timestamps import convert_frame_number_to_time
+from annolid.segmentation.SAM import MODELS
+
 __appname__ = 'Annolid'
 __version__ = "1.1.3"
@@ -452,6 +454,7 @@ def __init__(self,
         self.actions.tool = tuple(_action_tools)
         self.tools.clear()
+
         utils.addActions(self.tools, self.actions.tool)
         utils.addActions(self.menus.frames, (frames,))
         utils.addActions(self.menus.open_video, (open_video, open_audio))
@@ -482,6 +485,17 @@ def __init__(self,
         # Callbacks:
         self.zoomWidget.valueChanged.connect(self.paintCanvas)

+        self._selectAiModelComboBox.clear()
+        self._selectAiModelComboBox.addItems([model.name for model in MODELS])
+        self._selectAiModelComboBox.setCurrentIndex(1)
+        self._selectAiModelComboBox.currentIndexChanged.connect(
+            lambda: self.canvas.initializeAiModel(
+                name=self._selectAiModelComboBox.currentText()
+            )
+            if self.canvas.createMode in ["ai_polygon", "ai_mask"]
+            else None
+        )
+
         self.populateModeActions()

     def update_step_size(self, value):
diff --git a/annolid/gui/widgets/canvas.py b/annolid/gui/widgets/canvas.py
index 35ee515d..7609b94c 100644
--- a/annolid/gui/widgets/canvas.py
+++ b/annolid/gui/widgets/canvas.py
@@ -12,7 +12,7 @@
 import imgviz
 import labelme.ai
 from labelme.logger import logger
-
+from annolid.segmentation import SAM

 # TODO(unknown):
 # - [maybe] Find optimal epsilon value.
@@ -156,15 +156,15 @@ def createMode(self, value):
         self.current = None

     def initializeAiModel(self, name):
-        if name not in [model.name for model in labelme.ai.MODELS]:
+        if name not in [model.name for model in SAM.MODELS]:
             raise ValueError("Unsupported ai model: %s" % name)
-        model = [model for model in labelme.ai.MODELS if model.name == name][0]
+        model = [model for model in SAM.MODELS if model.name == name][0]

         if self._ai_model is not None and self._ai_model.name == model.name:
             logger.debug("AI model is already initialized: %r" % model.name)
         else:
             logger.debug("Initializing AI model: %r" % model.name)
-            self._ai_model = labelme.ai.SegmentAnythingModel(
+            self._ai_model = SAM.SegmentAnythingModel(
                 name=model.name,
                 encoder_path=gdown.cached_download(
                     url=model.encoder_weight.url,
@@ -568,7 +568,7 @@ def mousePressEvent(self, ev):
                 self.current.shape_type = "circle"
                 self.line.points = [pos, pos]
                 if (
-                    self.createMode in ["ai_polygon","ai_mask"]
+                    self.createMode in ["ai_polygon", "ai_mask"]
                     and is_shift_pressed
                 ):
                     self.line.point_labels = [0, 0]
@@ -845,10 +845,10 @@ def paintEvent(self, event):
         # draw crosshair
         if (self._crosshair[self._createMode]
-            and self.drawing()
-            and self.prevMovePoint
-            and not self.outOfPixmap(self.prevMovePoint)
-        ):
+                and self.drawing()
+                and self.prevMovePoint
+                and not self.outOfPixmap(self.prevMovePoint)
+                ):
             p.setPen(QtGui.QColor(0, 0, 0))
             p.drawLine(
                 0,
diff --git a/annolid/segmentation/SAM/__init__.py b/annolid/segmentation/SAM/__init__.py
index c4be08d6..133c4bde 100644
--- a/annolid/segmentation/SAM/__init__.py
+++ b/annolid/segmentation/SAM/__init__.py
@@ -16,6 +16,17 @@
 Weight = collections.namedtuple("Weight", ["url", "md5"])

 MODELS = [
+    Model(
+        name="Segment-Anything (Edge)",
+        encoder_weight=Weight(
+            url="https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_encoder.onnx",  # NOQA
+            md5="e0745d06f3ee9c5e01a667b56a40875b",
+        ),
+        decoder_weight=Weight(
+            url="https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x_decoder.onnx",  # NOQA
+            md5="9fe1d5521b4349ab710e9cc970936970",
+        ),
+    ),
     Model(
         name="Segment-Anything (speed)",
         encoder_weight=Weight(
diff --git a/annolid/segmentation/SAM/segment_anything.py b/annolid/segmentation/SAM/segment_anything.py
index 29a2a8ca..11b10ab1 100644
--- a/annolid/segmentation/SAM/segment_anything.py
+++ b/annolid/segmentation/SAM/segment_anything.py
@@ -7,7 +7,7 @@
 import onnxruntime
 import PIL.Image
 import skimage.measure
-
+import skimage.morphology
+import cv2
 from labelme.logger import logger
@@ -72,6 +72,18 @@ def predict_polygon_from_points(self, points, point_labels):
         )
         return polygon

+    def predict_mask_from_points(self, points, point_labels):
+        image_embedding = self._get_image_embedding()
+        mask = _compute_mask_from_points(
+            image_size=self._image_size,
+            decoder_session=self._decoder_session,
+            image=self._image,
+            image_embedding=image_embedding,
+            points=points,
+            point_labels=point_labels,
+        )
+        return mask
+

 def _compute_scale_to_resize_image(image_size, image):
     height, width = image.shape[:2]
@@ -99,6 +111,17 @@ def _resize_image(image_size, image):
     return scale, scaled_image


+def postprocess_masks(mask, img_size, input_size, original_size):
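+    # mask arrives as (1, C, 256, 256) decoder logits: upsample to the
+    # padded square model input, crop away the zero padding, then resize
+    # to the original frame and restore the (1, C, H, W) layout.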
+    mask = mask.squeeze(0).transpose(1, 2, 0)
+    mask = cv2.resize(mask, (img_size, img_size),
+                      interpolation=cv2.INTER_LINEAR)
+    mask = mask[:input_size[0], :input_size[1], :]
+    mask = cv2.resize(
+        mask, (original_size[1], original_size[0]),
+        interpolation=cv2.INTER_LINEAR)
+    mask = mask.transpose(2, 0, 1)[None, :, :, :]
+    return mask
+
+
 def _compute_image_embedding(image_size, encoder_session, image):
     image = imgviz.asrgb(image)
@@ -115,8 +138,12 @@ def _compute_image_embedding(image_size, encoder_session, image):
         ),
     )
     x = x.transpose(2, 0, 1)[None, :, :, :]
-
-    output = encoder_session.run(output_names=None, input_feed={"x": x})
+    input_names = [input.name for input in encoder_session.get_inputs()]
+    if input_names[0] == 'image':
+        output = encoder_session.run(
+            output_names=None, input_feed={"image": x})
+    else:
+        output = encoder_session.run(output_names=None, input_feed={"x": x})
     image_embedding = output[0]

     return image_embedding
@@ -128,7 +155,7 @@ def _get_contour_length(contour):
     return np.linalg.norm(contour_end - contour_start, axis=1).sum()


-def _compute_polygon_from_points(
+def _compute_mask_from_points(
     image_size, decoder_session, image, image_embedding, points, point_labels
 ):
     input_point = np.array(points, dtype=np.float32)
@@ -152,28 +179,61 @@
     onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
     onnx_has_mask_input = np.array([-1], dtype=np.float32)

-    decoder_inputs = {
-        "image_embeddings": image_embedding,
-        "point_coords": onnx_coord,
-        "point_labels": onnx_label,
-        "mask_input": onnx_mask_input,
-        "has_mask_input": onnx_has_mask_input,
-        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
-    }
+    input_names = [input.name for input in decoder_session.get_inputs()]
+    if len(input_names) <= 3:
+        outputs = decoder_session.run(None, {
+            'image_embeddings': image_embedding,
+            'point_coords': onnx_coord,
+            'point_labels': onnx_label,
+        })
+        scores, masks = outputs
+        masks = postprocess_masks(
+            masks, image_size, (new_height, new_width),
+            np.array(image.shape[:2]))

-    masks, _, _ = decoder_session.run(None, decoder_inputs)
+    else:
+        decoder_inputs = {
+            "image_embeddings": image_embedding,
+            "point_coords": onnx_coord,
+            "point_labels": onnx_label,
+            "mask_input": onnx_mask_input,
+            "has_mask_input": onnx_has_mask_input,
+            "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
+        }
+
+        masks, _, _ = decoder_session.run(None, decoder_inputs)

     mask = masks[0, 0]  # (1, 1, H, W) -> (H, W)
     mask = mask > 0.0
+
+    MIN_SIZE_RATIO = 0.05
+    skimage.morphology.remove_small_objects(
+        mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
+    )
+
     if 0:
         imgviz.io.imsave(
             "mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
         )
+    return mask
+
+
+def _compute_polygon_from_points(
+    image_size, decoder_session, image, image_embedding, points, point_labels
+):
+    mask = _compute_mask_from_points(
+        image_size=image_size,
+        decoder_session=decoder_session,
+        image=image,
+        image_embedding=image_embedding,
+        points=points,
+        point_labels=point_labels,
+    )

     contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
     contour = max(contours, key=_get_contour_length)
+    POLYGON_APPROX_TOLERANCE = 0.004
     polygon = skimage.measure.approximate_polygon(
         coords=contour,
-        tolerance=np.ptp(contour, axis=0).max() / 100,
+        tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
     )
     polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
     polygon = polygon[:-1]  # drop last point that is duplicate of first point
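
A note on the widget-facing API this adds: predict_mask_from_points()
returns the binary mask itself (for the canvas "ai_mask" mode), while
predict_polygon_from_points() keeps its old contract and now derives the
polygon from that mask. A hypothetical call site, with placeholder
coordinates and assuming the set_image() setter from the
labelme-derived SegmentAnythingModel wrapper:

    model.set_image(frame)  # RGB ndarray of the current video frame
    mask = model.predict_mask_from_points(
        points=[[420, 310]], point_labels=[1])
    polygon = model.predict_polygon_from_points(
        points=[[420, 310]], point_labels=[1])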