Skip to content

Commit

Permalink
Refactor polygon computation from points function
Browse files (browse the repository at this point in the history)
  • Loading branch information
healthonrails committed Jan 4, 2024
1 parent b52595b commit dba0480
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 27 deletions.
14 changes: 8 additions & 6 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ def calculate_polygon_center(polygon_vertices):
return np.array([(center_x, center_y)])


class VideoProcessor():
class VideoProcessor:
"""
A class for processing video frames using the Segment-Anything model.
"""

def __init__(self,
video_path,
num_center_points=3
num_center_points=3,
model_name="Segment-Anything (Edge)"
):
"""
Initialize the VideoProcessor.
Expand All @@ -44,19 +45,18 @@ def __init__(self,
- video_path (str): Path to the video file.
- num_center_points (int): number of center points for prompt.
"""
super(VideoProcessor, self).__init__()
self.video_path = video_path
self.video_folder = Path(video_path).with_suffix("")
self.video_loader = CV2Video(video_path)
self.sam_name = model_name
self.edge_sam = self.get_model()
self.num_frames = self.video_loader.total_frames()
self.center_points = MaxSizeQueue(max_size=num_center_points)
self.most_recent_file = self.get_most_recent_file()

def get_model(self,
encoder_path="edge_sam_3x_encoder.onnx",
decoder_path="edge_sam_3x_decoder.onnx",
name="Segment-Anything (Edge)"
decoder_path="edge_sam_3x_decoder.onnx"
):
"""
Load the Segment-Anything model.
Expand All @@ -69,6 +69,7 @@ def get_model(self,
Returns:
- SegmentAnythingModel: The loaded model.
"""
name = self.sam_name
model = SegmentAnythingModel(name, encoder_path, decoder_path)
return model

Expand Down Expand Up @@ -138,7 +139,8 @@ def process_frame(self, frame_number):
self.most_recent_file = filename
img_filename = str(filename.with_suffix('.png'))
cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
cv2.imwrite(img_filename, cur_frame)
if not Path(img_filename).exists():
cv2.imwrite(img_filename, cur_frame)
save_labels(filename=filename, imagePath=img_filename, label_list=label_list,
height=height, width=width)

Expand Down
27 changes: 6 additions & 21 deletions annolid/segmentation/SAM/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import imgviz
import numpy as np
import onnxruntime
import PIL.Image
import skimage.measure
import cv2
from labelme.logger import logger
Expand Down Expand Up @@ -219,6 +218,7 @@ def _compute_mask_from_points(
def _compute_polygon_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
from annolid.annotation.masks import mask_to_polygons
mask = _compute_mask_from_points(
image_size=image_size,
decoder_session=decoder_session,
Expand All @@ -227,23 +227,8 @@ def _compute_polygon_from_points(
points=points,
point_labels=point_labels,
)

contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
contour = max(contours, key=_get_contour_length)
POLYGON_APPROX_TOLERANCE = 0.004
polygon = skimage.measure.approximate_polygon(
coords=contour,
tolerance=np.ptp(contour, axis=0).max() * POLYGON_APPROX_TOLERANCE,
)
polygon = np.clip(polygon, (0, 0), (mask.shape[0] - 1, mask.shape[1] - 1))
polygon = polygon[:-1] # drop last point that is duplicate of first point
if 0:
image_pil = PIL.Image.fromarray(image)
imgviz.draw.line_(image_pil, yx=polygon, fill=(0, 255, 0))
for point in polygon:
imgviz.draw.circle_(
image_pil, center=point, diameter=10, fill=(0, 255, 0)
)
imgviz.io.imsave("contour.jpg", np.asarray(image_pil))

return polygon[:, ::-1] # yx -> xy
polygons, has_holes = mask_to_polygons(mask)
polys = polygons[0]
all_points = np.array(
list(zip(polys[0::2], polys[1::2])))
return all_points

0 comments on commit dba0480

Please sign in to comment.