Skip to content

Commit

Permalink
Refactor: Add polygon segmentation support for YOLO11n-seg
Browse files Browse the repository at this point in the history
Enhanced the YOLO11n-seg inference pipeline to include polygon segmentation output. The extract_yolo_results function now processes and extracts segmentation masks, converting them into polygon representations. These polygons are then included in the final output, enabling more precise and detailed object representation.
  • Loading branch information
healthonrails committed Dec 5, 2024
1 parent ead7327 commit 117ce39
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 90 deletions.
4 changes: 2 additions & 2 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,8 +1043,8 @@ def _select_sam_model_name(self):
"CoTracker": "CoTracker",
"sam2_hiera_s": "sam2_hiera_s",
"sam2_hiera_l": "sam2_hiera_l",
"YOLO11n": "yolo11n.pt",
"YOLO11x": "yolo11x.pt",
"YOLO11n": "yolo11n-seg.pt",
"YOLO11x": "yolo11x-seg.pt",
}
default_model_name = "Segment-Anything (Edge)"

Expand Down
185 changes: 97 additions & 88 deletions annolid/segmentation/yolos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import numpy as np
from ultralytics import YOLO, SAM
from annolid.gui.shape import Shape
from annolid.annotation.keypoints import save_labels
from collections import defaultdict


class InferenceProcessor:
Expand All @@ -12,119 +14,129 @@ def __init__(self, model_name, model_type, class_names=None):
Args:
model_name (str): Path or identifier for the model file.
model_type (str): Type of model ('yolo' or 'sam').
class_names (list, optional): List of class names for YOLO.
class_names (list, optional): List of class names for YOLO.
Defaults to None. Only provide
if the model doesn't have classes
built-in and you need to set them.
"""
self.model_type = model_type
self.model = self._load_model(model_name, class_names)
self.frame_count = 0 # Initialize the frame counter
self.frame_count = 0
self.track_history = defaultdict(lambda: [])

def _load_model(self, model_name, class_names):
"""
Loads the specified model based on the model type.
Args:
model_name (str): Path or identifier for the model file.
class_names (list, optional): List of class names for YOLO.
Returns:
A YOLO or SAM model instance.
"""
"""Loads the specified model."""
if self.model_type == 'yolo':
model = YOLO(model_name)
if 'world' in model_name and class_names:
if class_names: # Only set classes if provided
model.set_classes(class_names)
return model
elif self.model_type == 'sam':
model = SAM(model_name)
model.info() # Optional: Display model information
model.info()
return model
else:
raise ValueError("Unsupported model type. Use 'yolo' or 'sam'.")

def run_inference(self, source):
"""
Runs inference on the specified source and saves results to LabelMe JSON.
Args:
source (str): Path to the video or image source.
"""
# Ensure the output directory exists
"""Runs inference and saves results to LabelMe JSON."""
output_directory = os.path.splitext(source)[0]
os.makedirs(output_directory, exist_ok=True)

results = self.model(source, stream=True)
results = self.model.track(source, persist=True, stream=True)

# Process each frame result
for result in results:
frame_shape = (result.orig_shape[0], result.orig_shape[1], 3)
id_to_labels = {0: "mouse", 1: "teaball"} # Example label map
yolo_results = self.extract_yolo_results(result)
for result in results: # Corrected: Iterate through results generator
# Check if boxes exist
if result.boxes is not None and len(result.boxes):
frame_shape = (result.orig_shape[0], result.orig_shape[1], 3)
yolo_results = self.extract_yolo_results(result)
self.save_yolo_to_labelme(
yolo_results, frame_shape, output_directory)

self.save_yolo_to_labelme(
yolo_results, id_to_labels, frame_shape, output_directory
)
return f"Done#{self.frame_count}"

def extract_yolo_results(self, result):
"""
Extracts YOLO results from the inference result object.
"""Extracts YOLO results, emulating boxes if none are found."""
yolo_results = []

Args:
result: YOLO result object.
# Emulate boxes if none found, otherwise use actual boxes
if not result.boxes:
return yolo_results
else:
boxes = result.boxes.xywh.cpu()
track_ids = result.boxes.id.int().cpu().tolist() if result.boxes.id is not None else [
"" for _ in range(len(boxes))] # Check for track_ids
masks = result.masks
names = result.names
confidences = result.boxes.conf.cpu().tolist() if result.boxes.conf is not None else [
0.0 for _ in range(len(boxes))] # Check for confidences

for box, track_id, mask, name, conf in zip(boxes,
track_ids,
masks,
names, confidences):
x, y, w, h = box.tolist()

# Get the track history (will be empty if track_id is "")
track = self.track_history[track_id]
# Store only if track_id is not empty
track.append((float(x), float(y)))
if len(track) > 30:
track.pop(0)

x1, y1 = x - w / 2, y - h / 2
x2, y2 = x + w / 2, y + h / 2

# Include confidence in the label
box_label = f"{name}_{track_id} conf:{conf:.2f}"
box_points = [[x1, y1], [x2, y2]]
bbox_shape = Shape(box_label, shape_type='rectangle',
description=self.model_type,
flags={},
)
bbox_shape.points = box_points
yolo_results.append(bbox_shape)

# Only create track polygon if history exists and track_id is valid.
if len(track) > 1 and track_id != "":
track_points = np.array(track).tolist()
shape_track = Shape(f"track_{track_id}",
shape_type="polygon",
description=self.model_type,
flags={},
visible=True,
)
shape_track.points = track_points
yolo_results.append(shape_track)

if mask is not None:
try:
polygons = mask.xy
for polygon in polygons:
contour_points = polygon.tolist()
if len(contour_points) > 2:
# Include confidence in segmentation label
seg_label = f"{name}_{track_id} conf:{conf:.2f}"
segmentation_shape = Shape(
seg_label,
shape_type='polygon',
description=self.model_type,
flags={},
visible=True,
)
segmentation_shape.points = contour_points
yolo_results.append(segmentation_shape)
except Exception as e:
print(f"Error processing mask: {e}")

Returns:
A list of dictionaries containing bounding boxes and class IDs.
"""
yolo_results = []
for box in result.boxes:
yolo_results.append({
"cls": box.cls, # Class ID
"xyxy": box.xyxy # Bounding box coordinates
})
return yolo_results

def save_yolo_to_labelme(self, yolo_results, id_to_labels, frame_shape,
output_dir):
"""
Converts YOLO results to LabelMe JSON format and saves them.
Args:
yolo_results (list): YOLO results containing bounding boxes and labels.
id_to_labels (dict): Mapping of object IDs to readable labels.
frame_shape (tuple): Shape of the frame as (height, width, channels).
output_dir (str): Directory to save the LabelMe JSON files.
"""
def save_yolo_to_labelme(self, yolo_results, frame_shape, output_dir):
"""Saves YOLO results to LabelMe JSON."""
height, width, _ = frame_shape

# Construct the JSON filename using the frame count
json_filename = f"{self.frame_count:09d}.json"
output_path = os.path.join(output_dir, json_filename)
label_list = []

for result in yolo_results:
label_id = int(result["cls"].item())
bbox = result["xyxy"].squeeze().tolist()
if bbox:
if id_to_labels is not None:
label = id_to_labels.get(label_id, f"class_{label_id}")
else:
label = f"{label_id}"

# Convert bounding box to a polygon
x_min, y_min, x_max, y_max = bbox
points = [
[x_min, y_min], # Top-left
[x_max, y_min], # Top-right
[x_max, y_max], # Bottom-right
[x_min, y_max], # Bottom-left
]

# Create a MaskShape object
shape = Shape(label=label, flags={},
description="yolo_prediction")

shape.points = points
label_list.append(shape)
label_list = yolo_results # Directly use yolo_results

save_labels(
filename=output_path,
Expand All @@ -134,16 +146,13 @@ def save_yolo_to_labelme(self, yolo_results, id_to_labels, frame_shape,
width=width,
save_image_to_json=False,
)

# Increment the frame counter after saving
self.frame_count += 1


# Example usage
if __name__ == "__main__":
# Replace with your video path
video_path = os.path.expanduser("~/Downloads/IMG_0769.MOV")

yolo_processor = InferenceProcessor(
"yolo11n.pt", model_type="yolo"
)
# Using a readily available model
yolo_processor = InferenceProcessor("yolo11n-seg.pt", model_type="yolo")
yolo_processor.run_inference(video_path)

0 comments on commit 117ce39

Please sign in to comment.