Skip to content

Commit

Permalink
feat: add _find_best_model to locate 'best.pt' in common directories
Browse files Browse the repository at this point in the history
Added  to search for the 'best.pt' model file in typical folder structures. These include , , or  directories that may be downloaded from Colab and extracted into the Downloads or current directory. If the model is not found, the function falls back to a default model name, ensuring flexibility and ease of use.
  • Loading branch information
healthonrails committed Dec 6, 2024
1 parent ee5fb6c commit 17f88a4
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions annolid/segmentation/yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,36 @@ def __init__(self, model_name, model_type, class_names=None):
model_name (str): Path or identifier for the model file.
model_type (str): Type of model ('yolo' or 'sam').
class_names (list, optional): List of class names for YOLO.
Defaults to None. Only provide
Defaults to None. Only provide
if the model doesn't have classes
built-in and you need to set them.
"""
self.model_type = model_type
model_name = self._find_best_model(model_name)
self.model = self._load_model(model_name, class_names)
self.frame_count = 0
self.track_history = defaultdict(lambda: [])

def _find_best_model(self, model_name):
"""
Searches for 'best.pt' in potential directories and returns its path.
If not found, uses a default model.
"""
search_paths = [
os.path.expanduser("~/Downloads/best.pt"),
os.path.expanduser(
"~/Downloads/runs/segment/train/weights/best.pt"),
os.path.expanduser("~/Downloads/segment/train/weights/best.pt"),
"runs/segment/train/weights/best.pt",
"segment/train/weights/best.pt"
]
for path in search_paths:
if os.path.isfile(path):
print(f"Found model: {path}")
return path
print("best.pt not found, using default model")
return model_name

def _load_model(self, model_name, class_names):
"""Loads the specified model."""
if self.model_type == 'yolo':
Expand Down Expand Up @@ -153,6 +174,6 @@ def save_yolo_to_labelme(self, yolo_results, frame_shape, output_dir):
# Replace with your video path
video_path = os.path.expanduser("~/Downloads/IMG_0769.MOV")

# Using a readily available model
yolo_processor = InferenceProcessor("yolo11n-seg.pt", model_type="yolo")
# Automatically find best.pt or use default
yolo_processor = InferenceProcessor(model_type="yolo")
yolo_processor.run_inference(video_path)

0 comments on commit 17f88a4

Please sign in to comment.