diff --git a/setup.py b/setup.py index a71c64c..6512a73 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='yoflo', - version='0.4.3', + version='0.5.0', packages=find_packages(), include_package_data=True, install_requires=[ diff --git a/yoflo/yoflo.py b/yoflo/yoflo.py index 16e1c89..50bb6f9 100644 --- a/yoflo/yoflo.py +++ b/yoflo/yoflo.py @@ -28,6 +28,7 @@ def __init__( inference_limit=None, class_names=None, webcam_indices=None, + rtsp_urls=None, ): """Initialize the YO-FLO class with configuration options.""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -49,6 +50,7 @@ def __init__( self.last_inference_time = 0 self.inference_phrases = [] self.webcam_indices = webcam_indices if webcam_indices else [0] + self.rtsp_urls = rtsp_urls if rtsp_urls else [] if model_path: self.init_model(model_path) @@ -62,7 +64,6 @@ def init_model(self, model_path): f"Model path {os.path.abspath(model_path)} is not a directory." ) return - try: logging.info(f"Attempting to load model from {os.path.abspath(model_path)}") self.model = ( @@ -216,34 +217,45 @@ def download_model(self): logging.error(f"Error downloading model: {e}") def start_webcam_detection(self): - """Start separate threads for each specified webcam.""" + """Start separate threads for each specified webcam or RTSP stream.""" try: if self.webcam_threads: logging.warning("Webcam detection is already running.") return self.stop_webcam_flag.clear() - for index in self.webcam_indices: - thread = threading.Thread( - target=self._webcam_detection_thread, args=(index,) - ) - thread.start() - self.webcam_threads.append(thread) + if self.rtsp_urls: + for rtsp_url in self.rtsp_urls: + thread = threading.Thread( + target=self._webcam_detection_thread, args=(rtsp_url,) + ) + thread.start() + self.webcam_threads.append(thread) + else: + for index in self.webcam_indices: + thread = threading.Thread( + target=self._webcam_detection_thread, args=(index,) + ) + thread.start() + self.webcam_threads.append(thread) except Exception as e: logging.error(f"Error starting webcam detection: {e}") - def _webcam_detection_thread(self, index): - """Run the webcam detection loop in a separate thread for a specific webcam.""" + def _webcam_detection_thread(self, source): + """Run the webcam detection loop in a separate thread for a specific webcam or RTSP stream.""" try: - cap = cv2.VideoCapture(index) + if isinstance(source, str): + cap = cv2.VideoCapture(source) + else: + cap = cv2.VideoCapture(source) if not cap.isOpened(): - logging.error(f"Error: Could not open webcam {index}.") + logging.error(f"Error: Could not open video source {source}.") return - window_name = f"Object Detection Webcam {index}" + window_name = f"Object Detection Source {source}" while not self.stop_webcam_flag.is_set(): ret, frame = cap.read() if not ret: logging.error( - f"Error: Failed to capture image from webcam {index}." + f"Error: Failed to capture image from source {source}." ) break image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) @@ -268,7 +280,7 @@ def _webcam_detection_thread(self, index): self.pretty_print_detections(filtered_detections) else: logging.info( - f"Detections from webcam {index}: {filtered_detections}" + f"Detections from source {source}: {filtered_detections}" ) if not self.headless: frame = self.plot_bbox(frame, filtered_detections) @@ -279,7 +291,7 @@ def _webcam_detection_thread(self, index): self.save_screenshot(frame) if self.log_to_file_active: self.log_alert( - f"Detections from webcam {index}: {filtered_detections}" + f"Detections from source {source}: {filtered_detections}" ) elif self.phrase: results = self.run_expression_comprehension(image_pil, self.phrase) @@ -296,19 +308,19 @@ def _webcam_detection_thread(self, index): if clean_result in ["yes", "no"] and self.log_to_file_active: if self.log_to_file_active: self.log_alert( - f"Expression Comprehension from webcam {index}: {clean_result} at {datetime.now()}" + f"Expression Comprehension from source {source}: {clean_result} at {datetime.now()}" ) if self.inference_phrases: inference_result, phrase_results = self.evaluate_inference_chain( image_pil ) logging.info( - f"Inference Chain result from webcam {index}: {inference_result}, Details: {phrase_results}" + f"Inference Chain result from source {source}: {inference_result}, Details: {phrase_results}" ) if self.pretty_print: for idx, result in enumerate(phrase_results): logging.info( - f"Inference {idx + 1} from webcam {index}: {'PASS' if result else 'FAIL'}" + f"Inference {idx + 1} from source {source}: {'PASS' if result else 'FAIL'}" ) self.inference_count += 1 self.update_inference_rate() @@ -321,9 +333,9 @@ def _webcam_detection_thread(self, index): if not self.headless: cv2.destroyWindow(window_name) except cv2.error as e: - logging.error(f"OpenCV error in webcam detection thread {index}: {e}") + logging.error(f"OpenCV error in detection thread {source}: {e}") except Exception as e: - logging.error(f"Error in webcam detection thread {index}: {e}") + logging.error(f"Error in detection thread {source}: {e}") def stop_webcam_detection(self): """Stop all webcam detection threads.""" @@ -399,7 +411,6 @@ def evaluate_inference_chain(self, image): if not self.inference_phrases: logging.error("No inference phrases set.") return "FAIL", [] - results = [] for phrase in self.inference_phrases: result = self.run_expression_comprehension(image, phrase) @@ -408,7 +419,6 @@ def evaluate_inference_chain(self, image): results.append(True) else: results.append(False) - overall_result = "PASS" if results.count(True) >= 2 else "FAIL" return overall_result, results except Exception as e: @@ -483,6 +493,13 @@ def main(): type=int, help="Specify the indices of the webcams to use (e.g., 0 1 2).", ) + parser.add_argument( + "-rtsp", + "--rtsp_urls", + nargs="+", + type=str, + help="Specify the RTSP URLs for the video streams.", + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( @@ -501,10 +518,10 @@ def main(): args = parser.parse_args() if not args.model_path and not args.download_model: parser.error("You must specify either --model_path or --download_model.") - try: setup_logging(args.log_to_file) webcam_indices = args.webcam_indices if args.webcam_indices else [0] + rtsp_urls = args.rtsp_urls if args.rtsp_urls else [] if args.download_model: yo_flo = YOFLO( display_inference_speed=args.display_inference_speed, @@ -512,6 +529,7 @@ def main(): inference_limit=args.inference_limit, class_names=args.object_detection, webcam_indices=webcam_indices, + rtsp_urls=rtsp_urls, ) yo_flo.download_model() else: @@ -528,8 +546,8 @@ def main(): inference_limit=args.inference_limit, class_names=args.object_detection, webcam_indices=webcam_indices, + rtsp_urls=rtsp_urls, ) - if args.phrase: yo_flo.phrase = args.phrase if args.inference_chain: @@ -545,10 +563,8 @@ def main(): time.sleep(1) except KeyboardInterrupt: yo_flo.stop_webcam_detection() - except Exception as e: logging.error(f"An error occurred during main loop: {e}") - else: input("Press Enter to stop...") yo_flo.stop_webcam_detection()