Skip to content

Commit

Permalink
Implement the VideoProcessor class to enhance video frame processing …
Browse files Browse the repository at this point in the history
…by leveraging the capabilities of EdgeSAM.
  • Loading branch information
healthonrails committed Dec 22, 2023
1 parent e921510 commit 08d1649
Showing 1 changed file with 138 additions and 0 deletions.
138 changes: 138 additions & 0 deletions annolid/segmentation/SAM/edge_sam_bg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import cv2
from pathlib import Path
from segment_anything import SegmentAnythingModel
from annolid.data.videos import CV2Video
from annolid.utils.files import find_most_recent_file
import json
from annolid.gui.shape import Shape
from annolid.annotation.keypoints import save_labels


class VideoProcessor:
"""
A class for processing video frames using the Segment-Anything model.
"""

def __init__(self, video_path, encoder_path, decoder_path):
"""
Initialize the VideoProcessor.
Parameters:
- video_path (str): Path to the video file.
- encoder_path (str): Path to the encoder model file.
- decoder_path (str): Path to the decoder model file.
"""
self.video_path = video_path
self.video_folder = Path(video_path).with_suffix("")
self.video_loader = CV2Video(video_path)
self.edge_sam = self.get_model(encoder_path, decoder_path)

def get_model(self, encoder_path, decoder_path):
"""
Load the Segment-Anything model.
Parameters:
- encoder_path (str): Path to the encoder model file.
- decoder_path (str): Path to the decoder model file.
Returns:
- SegmentAnythingModel: The loaded model.
"""
name = "Segment-Anything (Edge)"
model = SegmentAnythingModel(name, encoder_path, decoder_path)
return model

def load_json_file(self, json_file_path):
"""
Load JSON file containing shapes and labels.
Parameters:
- json_file_path (str): Path to the JSON file.
Returns:
- tuple: A tuple containing two dictionaries (points_dict, point_labels_dict).
"""
with open(json_file_path, 'r') as json_file:
data = json.load(json_file)

points_dict = {}
point_labels_dict = {}

for shape in data.get('shapes', []):
label = shape.get('label')
points = shape.get('points', [])

if label and points:
points_dict[label] = points
# You can customize this if needed
point_labels_dict[label] = 1

return points_dict, point_labels_dict

def process_frame(self, frame_number):
"""
Process a single frame of the video.
Parameters:
- frame_number (int): Frame number to process.
"""
cur_frame = self.video_loader.load_frame(frame_number)

height, width, _ = cur_frame.shape

points_dict, _ = self.load_json_file(self.get_most_recent_file())
label_list = []

# Example usage of predict_polygon_from_points
for label, points in points_dict.items():
point_labels = [1] * len(points)
self.edge_sam.set_image(cur_frame)
polygon = self.edge_sam.predict_polygon_from_points(
points, point_labels)

# Save the LabelMe JSON to a file
p_shape = Shape(label, shape_type='polygon', flags={})
for x, y in polygon:
# do not add 0,0 to the list
if x >= 1 and y >= 1:
p_shape.addPoint((x, y))
label_list.append(p_shape)

filename = self.video_folder / \
(self.video_folder.name + f"_{frame_number:0>{9}}.json")
img_filename = str(filename.with_suffix('.png'))
cur_frame = cv2.cvtColor(cur_frame, cv2.COLOR_BGR2RGB)
cv2.imwrite(img_filename, cur_frame)
save_labels(filename=filename, imagePath=img_filename, label_list=label_list,
height=height, width=width)

def process_video_frames(self, start_frame, end_frame, step):
"""
Process multiple frames of the video.
Parameters:
- start_frame (int): Starting frame number.
- end_frame (int): Ending frame number.
- step (int): Step between frames.
"""
for i in range(start_frame, end_frame + 1, step):
self.process_frame(i)

def get_most_recent_file(self):
"""
Find the most recent file in the video folder.
Returns:
- str: Path to the most recent file.
"""
return find_most_recent_file(self.video_folder)


if __name__ == '__main__':
# Usage
video_path = "monkey_u.mp4"
encoder_path = "edge_sam_3x_encoder.onnx"
decoder_path = "edge_sam_3x_decoder.onnx"

video_processor = VideoProcessor(video_path, encoder_path, decoder_path)
video_processor.process_video_frames(30, 100, 10)

0 comments on commit 08d1649

Please sign in to comment.