From 17f88a4b4e9aac3caa997d1e4bdf1556ba18ef86 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Fri, 6 Dec 2024 17:36:50 -0500 Subject: [PATCH] feat: add _find_best_model to locate 'best.pt' in common directories Added to search for the 'best.pt' model file in typical folder structures. These include , , or directories that may be downloaded from Colab and extracted into the Downloads or current directory. If the model is not found, the function falls back to a default model name, ensuring flexibility and ease of use. --- annolid/segmentation/yolos.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/annolid/segmentation/yolos.py b/annolid/segmentation/yolos.py index 8995ac7..dfb2567 100644 --- a/annolid/segmentation/yolos.py +++ b/annolid/segmentation/yolos.py @@ -16,15 +16,36 @@ def __init__(self, model_name, model_type, class_names=None): model_name (str): Path or identifier for the model file. model_type (str): Type of model ('yolo' or 'sam'). class_names (list, optional): List of class names for YOLO. - Defaults to None. Only provide + Defaults to None. Only provide if the model doesn't have classes built-in and you need to set them. """ self.model_type = model_type + model_name = self._find_best_model(model_name) self.model = self._load_model(model_name, class_names) self.frame_count = 0 self.track_history = defaultdict(lambda: []) + def _find_best_model(self, model_name): + """ + Searches for 'best.pt' in potential directories and returns its path. + If not found, uses a default model. + """ + search_paths = [ + os.path.expanduser("~/Downloads/best.pt"), + os.path.expanduser( + "~/Downloads/runs/segment/train/weights/best.pt"), + os.path.expanduser("~/Downloads/segment/train/weights/best.pt"), + "runs/segment/train/weights/best.pt", + "segment/train/weights/best.pt" + ] + for path in search_paths: + if os.path.isfile(path): + print(f"Found model: {path}") + return path + print("best.pt not found, using default model") + return model_name + def _load_model(self, model_name, class_names): """Loads the specified model.""" if self.model_type == 'yolo': @@ -153,6 +174,6 @@ def save_yolo_to_labelme(self, yolo_results, frame_shape, output_dir): # Replace with your video path video_path = os.path.expanduser("~/Downloads/IMG_0769.MOV") - # Using a readily available model - yolo_processor = InferenceProcessor("yolo11n-seg.pt", model_type="yolo") + # Automatically find best.pt or use default + yolo_processor = InferenceProcessor(model_type="yolo") yolo_processor.run_inference(video_path)