From 5d94f8be29e9b187b303cf2faf582367605e4721 Mon Sep 17 00:00:00 2001
From: PortfolioAI <135471798+CharlesCNorton@users.noreply.github.com>
Date: Wed, 7 Aug 2024 13:48:56 -0400
Subject: [PATCH] Enhance Exception Handling & Documentation (v0.8.1)

- Reverted float16 handling for input compatibility.
- Added exception handling for CUDA and data type mismatches.
- Enhanced docstrings for better code clarity and maintainability.
- Updated output file naming with timestamps to prevent overwrites.
---
 setup.py       |   2 +-
 yoflo/yoflo.py | 187 ++++++++++++++++++++++++++++++++++---------------
 2 files changed, 132 insertions(+), 57 deletions(-)

diff --git a/setup.py b/setup.py
index d5f64a5..f8b6e16 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name='yoflo',
-    version='0.8.0',
+    version='0.8.1',
     packages=find_packages(),
     include_package_data=True,
     install_requires=[
diff --git a/yoflo/yoflo.py b/yoflo/yoflo.py
index 65c9ad2..0f47e49 100644
--- a/yoflo/yoflo.py
+++ b/yoflo/yoflo.py
@@ -11,11 +11,15 @@
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers import BitsAndBytesConfig
 
+
 def setup_logging(log_to_file, log_file_path="alerts.log"):
-    """Set up logging to file and/or console.
+    """
+    Set up logging to file and/or console.
+
+    This function configures the logging module to log messages to the console and, optionally, to a specified log file.
 
     Args:
-        log_to_file (bool): Whether to log to a file.
+        log_to_file (bool): Whether to log messages to a file.
         log_file_path (str, optional): Path to the log file. Defaults to "alerts.log".
     """
     handlers = [logging.StreamHandler()]
@@ -23,6 +27,7 @@ def setup_logging(log_to_file, log_file_path="alerts.log"):
         handlers.append(logging.FileHandler(log_file_path))
     logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=handlers)
 
+
 class YOFLO:
     def __init__(
         self,
@@ -34,9 +39,12 @@ def __init__(
         webcam_indices=None,
         rtsp_urls=None,
         record=None,
-        quantization=None
+        quantization=None,
     ):
-        """Initialize the YO-FLO class with configuration options.
+        """
+        Initialize the YO-FLO class with configuration options.
+
+        This constructor initializes the YO-FLO object with various settings for the model, display, inference, and video processing.
 
         Args:
             model_path (str, optional): Path to the pre-trained model directory. Defaults to None.
@@ -72,13 +80,16 @@ def __init__(
         self.record = record
         self.recording = False
         self.video_writer = None
-        self.video_out_path = "output.avi"
         self.quantization = quantization
+        self.video_out_path = f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.avi"
         if model_path:
             self.init_model(model_path)
 
     def init_model(self, model_path):
-        """Initialize the model and processor from the given model path.
+        """
+        Initialize the model and processor from the given model path.
+
+        This method loads a pre-trained model and its processor from a specified directory and prepares them for inference, handling quantization settings if specified.
 
         Args:
             model_path (str): Path to the pre-trained model directory.
@@ -98,20 +109,23 @@ def init_model(self, model_path):
         """
         try:
             quantization_config = None
             if self.quantization == "4bit":
                 quantization_config = BitsAndBytesConfig(
                     load_in_4bit=True,
-                    bnb_4bit_compute_dtype=torch.bfloat16,
-                    bnb_4bit_use_double_quant=True
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True,
                 )
                 logging.info("Using 4-bit quantization.")
-
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 trust_remote_code=True,
-                quantization_config=quantization_config
+                quantization_config=quantization_config,
             ).eval()
             if not self.quantization:
+                self.model.to(self.device)
+                if torch.cuda.is_available():
+                    self.model = self.model.half()
+                    logging.info("Using FP16 precision for the model.")
             self.processor = AutoProcessor.from_pretrained(
                 model_path, trust_remote_code=True
             )
@@ -124,7 +138,11 @@ def init_model(self, model_path):
             logging.error(f"Unexpected error initializing model: {e}")
 
     def update_inference_rate(self):
-        """Calculate and log the inference rate (inferences per second)."""
+        """
+        Calculate and log the inference rate (inferences per second).
+
+        This method calculates the rate of inferences over time and logs it, helping users understand the model's performance in real time.
+        """
         try:
             if self.inference_start_time is None:
                 self.inference_start_time = time.time()
@@ -138,7 +156,10 @@ def update_inference_rate(self):
             logging.error(f"Error updating inference rate: {e}")
 
     def run_object_detection(self, image):
-        """Perform object detection on the given image.
+        """
+        Perform object detection on the given image.
+
+        This method runs object detection on the provided image using the initialized model and processor, returning parsed detection results.
 
         Args:
             image (PIL.Image): The image to perform object detection on.
@@ -148,21 +169,18 @@ def run_object_detection(self, image):
         """
         try:
             task_prompt = "<OD>"
-            inputs = self.processor(
-                text=task_prompt, images=image, return_tensors="pt"
-            )
+            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
 
-            if not self.quantization:
-                inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            if self.quantization:
-                dtype = next(self.model.parameters()).dtype
-                inputs = {k: v.to(self.model.device, dtype=dtype) if torch.is_floating_point(v) else v for k, v in inputs.items()}
+            dtype = next(self.model.parameters()).dtype
+            inputs = {
+                k: v.to(self.device, dtype=dtype) if torch.is_floating_point(v) else v
+                for k, v in inputs.items()
+            }
 
             with torch.no_grad():
                 generated_ids = self.model.generate(
-                    input_ids=inputs["input_ids"].to(self.model.device),
-                    pixel_values=inputs.get("pixel_values").to(self.model.device),
+                    input_ids=inputs["input_ids"].to(self.device),
+                    pixel_values=inputs.get("pixel_values").to(self.device),
                     max_new_tokens=1024,
                     early_stopping=False,
                     do_sample=False,
@@ -182,7 +200,10 @@ def run_object_detection(self, image):
             return None
 
     def filter_detections(self, detections):
-        """Filter detections to include only specified class names.
+        """
+        Filter detections to include only specified class names.
+
+        This method filters the raw detections returned by the model to include only the objects specified in the `class_names` attribute.
 
         Args:
             detections (list): List of detections.
@@ -204,7 +225,10 @@ def filter_detections(self, detections):
             return detections
 
     def run_expression_comprehension(self, image, phrase):
-        """Run expression comprehension on the given image and phrase.
+        """
+        Run expression comprehension on the given image and phrase.
+
+        This method evaluates a given phrase against the provided image to determine whether the expression is present, using the initialized model.
 
         Args:
             image (PIL.Image): The image to run expression comprehension on.
@@ -215,24 +239,22 @@ def run_expression_comprehension(self, image, phrase):
         """
         try:
             task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
-            inputs = self.processor(
-                text=task_prompt, images=image, return_tensors="pt"
-            )
+            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
             inputs["input_ids"] = self.processor.tokenizer(
                 phrase, return_tensors="pt"
             ).input_ids
 
-            if not self.quantization:
-                inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            if self.quantization:
-                dtype = next(self.model.parameters()).dtype
-                inputs = {k: v.to(dtype=dtype) if torch.is_floating_point(v) else v for k, v in inputs.items()}
+            # Determine the appropriate dtype
+            dtype = next(self.model.parameters()).dtype
+            inputs = {
+                k: v.to(self.device, dtype=dtype) if torch.is_floating_point(v) else v
+                for k, v in inputs.items()
+            }
 
             with torch.no_grad():
                 generated_ids = self.model.generate(
-                    input_ids=inputs["input_ids"].to(self.model.device),
-                    pixel_values=inputs.get("pixel_values").to(self.model.device),
+                    input_ids=inputs["input_ids"].to(self.device),
+                    pixel_values=inputs.get("pixel_values").to(self.device),
                     max_new_tokens=1024,
                     early_stopping=False,
                     do_sample=False,
@@ -249,7 +271,10 @@ def run_expression_comprehension(self, image, phrase):
             return None
 
     def plot_bbox(self, image, detections):
-        """Draw bounding boxes on the image based on detections.
+        """
+        Draw bounding boxes on the image based on detections.
+
+        This method draws bounding boxes around detected objects on the image using OpenCV, labeling them with their corresponding class names.
 
         Args:
             image (numpy.ndarray): The image to draw bounding boxes on.
@@ -279,7 +304,14 @@ def plot_bbox(self, image, detections):
         return image
 
     def download_model(self):
-        """Download the model and processor from Hugging Face Hub."""
+        """
+        Download the model and processor from the Hugging Face Hub.
+
+        This method automatically downloads the pre-trained model files from the Hugging Face Hub and initializes them for use.
+
+        Returns:
+            bool: True if the download and initialization are successful, False otherwise.
+        """
         try:
             local_model_dir = "model"
             snapshot_download(
@@ -307,7 +339,11 @@ def download_model(self):
             return False
 
     def start_webcam_detection(self):
-        """Start separate threads for each specified webcam or RTSP stream."""
+        """
+        Start separate threads for each specified webcam or RTSP stream.
+
+        This method initiates a separate detection thread for each webcam index or RTSP URL specified during initialization.
+        """
         try:
             if self.webcam_threads:
                 logging.warning("Webcam detection is already running.")
@@ -331,7 +367,10 @@ def start_webcam_detection(self):
             logging.error(f"Error starting webcam detection: {e}")
 
     def _webcam_detection_thread(self, source):
-        """Run the webcam detection loop in a separate thread for a specific webcam or RTSP stream.
+        """
+        Run the webcam detection loop in a separate thread for a specific webcam or RTSP stream.
+
+        This method handles video capture and object detection in a loop, processing each frame in real time.
 
         Args:
             source (str or int): The source index or RTSP URL for the webcam.
@@ -445,7 +484,11 @@ def _webcam_detection_thread(self, source):
             logging.error(f"Error in detection thread {source}: {e}")
 
     def stop_webcam_detection(self):
-        """Stop all webcam detection threads."""
+        """
+        Stop all webcam detection threads.
+
+        This method stops the webcam detection process by signaling all running threads to terminate gracefully.
+        """
         try:
             self.object_detection_active = False
             self.stop_webcam_flag.set()
@@ -459,7 +502,10 @@ def stop_webcam_detection(self):
             logging.error(f"Error stopping webcam detection: {e}")
 
     def save_screenshot(self, frame):
-        """Save a screenshot of the current frame.
+        """
+        Save a screenshot of the current frame.
+
+        This method captures a screenshot of the current frame being processed and saves it as a PNG file with a timestamped filename.
 
         Args:
             frame (numpy.ndarray): The frame to save as a screenshot.
@@ -475,7 +521,10 @@ def save_screenshot(self, frame):
             logging.error(f"Error saving screenshot: {e}")
 
     def log_alert(self, message):
-        """Log an alert message to a file.
+        """
+        Log an alert message to a file.
+
+        This method appends alert messages to a log file, including a timestamp, for record-keeping and analysis.
 
         Args:
             message (str): The alert message to log.
@@ -491,7 +540,10 @@ def log_alert(self, message):
             logging.error(f"Error logging alert: {e}")
 
     def pretty_print_detections(self, detections):
-        """Pretty print the detections to the console.
+        """
+        Pretty print the detections to the console.
+
+        This method formats and prints detection results in a human-readable form for easy interpretation.
 
         Args:
             detections (list): List of detections to print.
@@ -509,7 +561,10 @@ def pretty_print_detections(self, detections):
             logging.error(f"Error in pretty_print_detections: {e}")
 
     def pretty_print_expression(self, clean_result):
-        """Pretty print the expression comprehension result to the console.
+        """
+        Pretty print the expression comprehension result to the console.
+
+        This method formats and prints the result of expression comprehension, highlighting the outcome in a readable format.
 
         Args:
             clean_result (str): The clean result to print.
@@ -526,7 +581,10 @@ def pretty_print_expression(self, clean_result):
             logging.error(f"Error in pretty_print_expression: {e}")
 
     def set_inference_phrases(self, phrases):
-        """Set the phrases for the inference chain.
+        """
+        Set the phrases for the inference chain.
+
+        This method allows setting a list of phrases to be evaluated in the inference chain for expression comprehension tasks.
 
         Args:
             phrases (list): List of phrases for the inference chain.
@@ -535,7 +593,10 @@ def set_inference_phrases(self, phrases):
         logging.info(f"Inference phrases set: {self.inference_phrases}")
 
     def evaluate_inference_chain(self, image):
-        """Evaluate the inference chain based on the set phrases.
+        """
+        Evaluate the inference chain based on the set phrases.
+
+        This method processes the image against each phrase in the inference chain, determining an overall result based on individual outcomes.
 
         Args:
             image (PIL.Image): The image to evaluate.
@@ -562,7 +623,10 @@ def evaluate_inference_chain(self, image):
         return "FAIL", []
 
     def start_recording(self, frame):
-        """Start recording the video.
+        """
+        Start recording the video.
+
+        This method initiates video recording, setting up the video writer based on the frame's dimensions.
 
         Args:
             frame (numpy.ndarray): The frame to use for setting the video writer.
@@ -572,9 +636,9 @@ def start_recording(self, frame):
             height, width, _ = frame.shape
             self.video_writer = cv2.VideoWriter(
                 self.video_out_path,
-                cv2.VideoWriter_fourcc(*'XVID'),
+                cv2.VideoWriter_fourcc(*"XVID"),
                 20.0,
-                (width, height)
+                (width, height),
             )
             self.recording = True
             logging.info(f"Started recording video: {self.video_out_path}")
@@ -582,7 +646,11 @@ def start_recording(self, frame):
             logging.error(f"Error starting video recording: {e}")
 
     def stop_recording(self):
-        """Stop recording the video."""
+        """
+        Stop recording the video.
+
+        This method stops the video recording process and releases the video writer.
+        """
         try:
             if self.recording:
                 self.video_writer.release()
@@ -592,7 +660,10 @@ def stop_recording(self):
             logging.error(f"Error stopping video recording: {e}")
 
     def handle_recording_by_inference(self, inference_result, frame):
-        """Handle recording based on inference result.
+        """
+        Handle recording based on inference result.
+
+        This method controls video recording based on the results of expression comprehension, starting or stopping recording as needed.
 
         Args:
             inference_result (str): The inference result ("yes" or "no").
@@ -610,8 +681,13 @@ def handle_recording_by_inference(self, inference_result, frame):
         except Exception as e:
             logging.error(f"Error handling recording by inference: {e}")
 
+
 def main():
-    """Parse command-line arguments and run the YO-FLO application."""
+    """
+    Parse command-line arguments and run the YO-FLO application.
+
+    This is the main function that sets up the YO-FLO application, parsing command-line arguments and initiating the object detection process.
+    """
     parser = argparse.ArgumentParser(
         description="YO-FLO: A proof-of-concept in using advanced vision-language models as a YOLO alternative."
    )
@@ -715,7 +791,6 @@ def main():
     args = parser.parse_args()
     if not args.model_path and not args.download_model:
         parser.error("You must specify either --model_path or --download_model.")
-
     quantization_mode = "4bit" if args.four_bit else None
 
     try:
@@ -731,7 +806,7 @@ def main():
                 webcam_indices=webcam_indices,
                 rtsp_urls=rtsp_urls,
                 record=args.record,
-                quantization=quantization_mode
+                quantization=quantization_mode,
             )
             if not yo_flo.download_model():
                 return
@@ -751,7 +826,7 @@ def main():
                 webcam_indices=webcam_indices,
                 rtsp_urls=rtsp_urls,
                 record=args.record,
-                quantization=quantization_mode
+                quantization=quantization_mode,
             )
         if args.phrase:
            yo_flo.phrase = args.phrase
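
Note (not part of the patch): the 4-bit path now uses float16 as the compute
dtype so it matches the FP16 input path. A minimal standalone sketch of the
loading logic this patch settles on; the model id is illustrative, since the
patch itself does not pin one:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 4-bit NF4 weights with FP16 compute (requires bitsandbytes and a CUDA GPU).
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # was bfloat16 before this patch
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base",  # illustrative checkpoint, not pinned by this patch
        trust_remote_code=True,
        quantization_config=quantization_config,
    ).eval()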
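Note (not part of the patch): the substantive change in run_object_detection()
and run_expression_comprehension() is that floating-point input tensors are now
cast to whatever dtype the loaded model actually uses, so one code path serves
both the 4-bit quantized and FP16 models. A minimal standalone sketch of that
idea; the helper name cast_inputs_to_model is illustrative:

    import torch

    def cast_inputs_to_model(inputs, model, device):
        """Move inputs to `device`, casting float tensors to the model's dtype."""
        # Integer tensors such as input_ids keep their dtype; only floating-point
        # tensors such as pixel_values are converted, mirroring the patched logic.
        dtype = next(model.parameters()).dtype
        return {
            k: v.to(device, dtype=dtype) if torch.is_floating_point(v) else v
            for k, v in inputs.items()
        }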
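Note (not part of the patch): the timestamped output name set in __init__ means
each YOFLO instance records to a unique file instead of overwriting output.avi.
A minimal sketch of the naming and writer setup used by start_recording(); the
640x480 frame size is illustrative, since the real code takes it from
frame.shape:

    from datetime import datetime

    import cv2

    # Unique per-run filename, as introduced in __init__ above.
    video_out_path = f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.avi"
    width, height = 640, 480  # taken from frame.shape in the real code
    writer = cv2.VideoWriter(
        video_out_path,
        cv2.VideoWriter_fourcc(*"XVID"),  # XVID codec at 20 FPS, as in the patch
        20.0,
        (width, height),
    )
    # ...call writer.write(frame) per frame, then writer.release() on stop.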