Skip to content

Commit

Permalink
Refine polygon prediction by sampling points within the polygon for i…
Browse files Browse the repository at this point in the history
…mproved accuracy in the subsequent frame
  • Loading branch information
healthonrails committed Jan 5, 2024
1 parent dba0480 commit 244afde
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 6 deletions.
114 changes: 111 additions & 3 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,85 @@
from annolid.annotation.keypoints import save_labels
import numpy as np
from collections import deque
from shapely.geometry import Point, Polygon


def uniform_points_inside_polygon(polygon, num_points):
# Get the bounding box of the polygon
min_x, min_y, max_x, max_y = polygon.bounds

# Generate random points within the bounding box
random_points = np.column_stack((np.random.uniform(min_x, max_x, num_points),
np.random.uniform(min_y, max_y, num_points)))

# Filter points that are inside the polygon
inside_points = [
point for point in random_points if Point(point).within(polygon)]

return np.array(inside_points)


def find_polygon_center(polygon_points):
# Convert the list of polygon points to a Shapely Polygon
polygon = Polygon(polygon_points)

# Find the center of the polygon
center = polygon.centroid

return center


def random_sample_near_center(center, num_points, max_distance):
# Randomly sample points near the center
sampled_points = []
for _ in range(num_points):
# Generate random angle and radius
angle = np.random.uniform(0, 2 * np.pi)
radius = np.random.uniform(0, max_distance)

# Calculate new point coordinates
x = center.x + radius * np.cos(angle)
y = center.y + radius * np.sin(angle)

sampled_points.append((x, y))

return np.array(sampled_points)


def random_sample_inside_edges(polygon, num_points):
# Randomly sample points inside the edges of the polygon
sampled_points = []
min_x, min_y, max_x, max_y = polygon.bounds

for _ in range(num_points):
# Generate random point inside the bounding box
x = np.random.uniform(min_x, max_x)
y = np.random.uniform(min_y, max_y)
point = Point(x, y)

# Check if the point is inside the polygon
if point.within(polygon):
sampled_points.append((x, y))

return np.array(sampled_points)


def find_bbox(polygon_points):
# Convert the list of polygon points to a NumPy array
points_array = np.array(polygon_points)

# Calculate the bounding box
min_x, min_y = np.min(points_array, axis=0)
max_x, max_y = np.max(points_array, axis=0)

# Calculate the center point of the bounding box
center_x = (min_x + max_x) / 2
center_y = (min_y + max_y) / 2

# Return the center point as a NumPy array
bbox_center = np.array([(center_x, center_y)])

return bbox_center


class MaxSizeQueue(deque):
Expand Down Expand Up @@ -53,6 +132,7 @@ def __init__(self,
self.num_frames = self.video_loader.total_frames()
self.center_points = MaxSizeQueue(max_size=num_center_points)
self.most_recent_file = self.get_most_recent_file()
self.num_points_inside_edges = 3

def get_model(self,
encoder_path="edge_sam_3x_encoder.onnx",
Expand Down Expand Up @@ -83,6 +163,8 @@ def load_json_file(self, json_file_path):
Returns:
- tuple: A tuple containing two dictionaries (points_dict, point_labels_dict).
"""
import labelme
from annolid.annotation.masks import mask_to_polygons
with open(json_file_path, 'r') as json_file:
data = json.load(json_file)

Expand All @@ -92,10 +174,16 @@ def load_json_file(self, json_file_path):
for shape in data.get('shapes', []):
label = shape.get('label')
points = shape.get('points', [])

if label and points:
mask = labelme.utils.img_b64_to_arr(
shape["mask"]) if shape.get("mask") else None
if mask is not None:
polygons, has_holes = mask_to_polygons(mask)
polys = polygons[0]
points = np.array(
list(zip(polys[0::2], polys[1::2])))

if label and points is not None:
points_dict[label] = points
# You can customize this if needed
point_labels_dict[label] = 1

return points_dict, point_labels_dict
Expand All @@ -119,9 +207,29 @@ def process_frame(self, frame_number):
# Example usage of predict_polygon_from_points
for label, points in points_dict.items():
self.edge_sam.set_image(cur_frame)
orig_points = points
bbox_points = find_bbox(points)
points = calculate_polygon_center(points)

polygon = Polygon(orig_points)
# Find the center of the polygon
# Randomly sample points inside the edges of the polygon
points_inside_edges = random_sample_inside_edges(polygon,
self.num_points_inside_edges)
points_uni = uniform_points_inside_polygon(polygon, 3)

self.center_points.enqueue(points[0])
points = self.center_points.to_numpy()
points = np.concatenate(
(points, bbox_points), axis=0)
if len(points_inside_edges.shape) > 1:
points = np.concatenate(
(points, points_inside_edges), axis=0)
if len(points_uni) > 1:
points = np.concatenate(
(points, points_uni), axis=0
)

point_labels = [1] * len(points)
polygon = self.edge_sam.predict_polygon_from_points(
points, point_labels)
Expand Down
4 changes: 2 additions & 2 deletions annolid/segmentation/SAM/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def _get_contour_length(contour):


def _compute_mask_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
image_size, decoder_session, image, image_embedding,
points, point_labels):
input_point = np.array(points, dtype=np.float32)
input_label = np.array(point_labels, dtype=np.int32)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ opencv-contrib-python>=4.1.2.30
pycocotools>=2.0.2
simplification==0.6.11
pandas>=1.1.3
shapely>=1.7.1
shapely>=2.0.2
scipy
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'pycocotools>=2.0.2',
'tensorboard>=2.3.0',
'imageio>=2.8.0',
'shapely>=2.0.2',
'imageio-ffmpeg>=0.4.2',
'qimage2ndarray>=1.8',
'simplification==0.6.11',
Expand Down

0 comments on commit 244afde

Please sign in to comment.