From 4554e36dfe270846cbf224cda26e2492dbfbe12c Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 25 Nov 2024 11:20:03 +0800 Subject: [PATCH 1/2] dev(narugo): add preprocessor --- imgutils/generic/classify.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py index 3eb053810b..6166b62c71 100644 --- a/imgutils/generic/classify.py +++ b/imgutils/generic/classify.py @@ -19,7 +19,7 @@ import json import os from threading import Lock -from typing import Tuple, Optional, List, Dict +from typing import Tuple, Optional, List, Dict, Callable import numpy as np from PIL import Image @@ -88,6 +88,9 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), return data.astype(np.float32) +ImagePreprocessFunc = Callable[[Image.Image], Image.Image] + + class ClassifyModel: """ A class for managing and using classification models. @@ -114,7 +117,8 @@ class ClassifyModel: >>> print(f"Predicted class: {prediction}, Score: {score}") """ - def __init__(self, repo_id: str, hf_token: Optional[str] = None): + def __init__(self, repo_id: str, fn_preprocess: Optional[ImagePreprocessFunc] = None, + hf_token: Optional[str] = None): """ Initialize the ClassifyModel instance. @@ -124,6 +128,7 @@ def __init__(self, repo_id: str, hf_token: Optional[str] = None): :type hf_token: Optional[str], optional """ self.repo_id = repo_id + self._fn_preprocess = fn_preprocess self._model_names = None self._models = {} self._labels = {} @@ -254,6 +259,9 @@ def _raw_predict(self, image: ImageTyping, model_name: str): raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, ' f'channels not 3.') # pragma: no cover + if self._fn_preprocess: + image = self._fn_preprocess(image) + if isinstance(height, int) and isinstance(width, int): input_ = _img_encode(image, size=(width, height))[None, ...] 
else: From cb17dac87f172970eb0c23e1a12c4ea22075ebd1 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 25 Nov 2024 11:24:32 +0800 Subject: [PATCH 2/2] dev(narugo): add preprocessor onto the cls model pred --- imgutils/generic/classify.py | 145 +++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 68 deletions(-) diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py index 6166b62c71..0ad644f904 100644 --- a/imgutils/generic/classify.py +++ b/imgutils/generic/classify.py @@ -3,15 +3,9 @@ This module provides utilities and classes for working with classification models, particularly those stored in Hugging Face repositories. It includes functions for -image encoding, model loading, and prediction, as well as a main `ClassifyModel` class +image encoding, model loading, and prediction, as well as a main ClassifyModel class that manages the interaction with classification models. -Key components: - -- Image encoding and preprocessing -- ClassifyModel: A class for managing and using classification models -- Utility functions for making predictions with classification models - The module is designed to work with ONNX models and supports various image input formats. It also handles token-based authentication for accessing private Hugging Face repositories. """ @@ -45,9 +39,12 @@ def _check_gradio_env(): """ - Check if the Gradio library is installed and available. + Verify that Gradio library is properly installed and available. + + This function checks if the Gradio package is accessible for creating + web-based demos. If Gradio is not found, it provides instructions for installation. - :raises EnvironmentError: If Gradio is not installed. + :raises EnvironmentError: If Gradio package is not installed in the environment. 
""" if gr is None: raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' @@ -57,25 +54,30 @@ def _check_gradio_env(): def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): """ - Encode an image into a numpy array for model input. + Encode an image into a numpy array suitable for model input. - This function resizes the input image, converts it to RGB format, and optionally - normalizes the pixel values. + This function performs several preprocessing steps on the input image: + 1. Resizes the image to the specified dimensions + 2. Converts to RGB format + 3. Applies normalization if parameters are provided + 4. Returns the image in CHW (Channel, Height, Width) format - :param image: The input image to be encoded. + :param image: Input PIL Image to be encoded :type image: Image.Image - :param size: The target size (width, height) to resize the image to, defaults to (384, 384). - :type size: Tuple[int, int], optional - :param normalize: The mean and standard deviation for normalization, defaults to (0.5, 0.5). - If None, no normalization is applied. - :type normalize: Optional[Tuple[float, float]], optional + :param size: Target dimensions (width, height) for resizing, defaults to (384, 384) + :type size: Tuple[int, int] + :param normalize: Normalization parameters (mean, std), defaults to (0.5, 0.5) + :type normalize: Optional[Tuple[float, float]] - :return: The encoded image as a numpy array in CHW format. + :return: Encoded and preprocessed image as numpy array :rtype: np.ndarray - :raises TypeError: If the input image is not a PIL Image object. 
+ :raises TypeError: If input is not a PIL Image + + Example: + >>> img = Image.open('example.jpg') + >>> encoded = _img_encode(img, size=(224, 224)) """ - # noinspection PyUnresolvedReferences image = image.resize(size, Image.BILINEAR) data = rgb_encode(image, order_='CHW') @@ -93,39 +95,38 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), class ClassifyModel: """ - A class for managing and using classification models. - - This class provides methods for loading classification models from a Hugging Face - repository, making predictions, and managing model resources. It supports multiple - models within a single repository and handles token-based authentication. + A comprehensive manager for classification models from Hugging Face repositories. - :param repo_id: The Hugging Face repository ID containing the classification models. + :param repo_id: Hugging Face repository identifier :type repo_id: str - :param hf_token: The Hugging Face API token for accessing private repositories, defaults to None. - :type hf_token: Optional[str], optional + :param fn_preprocess: Optional custom preprocessing function + :type fn_preprocess: Optional[ImagePreprocessFunc] + :param hf_token: Hugging Face authentication token + :type hf_token: Optional[str] - :ivar repo_id: The Hugging Face repository ID. - :ivar _model_names: Cached list of available model names in the repository. - :ivar _models: Dictionary of loaded ONNX models. - :ivar _labels: Dictionary of labels for each model. - :ivar _hf_token: The Hugging Face API token. 
+ :ivar repo_id: Repository identifier + :ivar _model_names: Cached list of available models + :ivar _models: Dictionary of loaded ONNX models + :ivar _labels: Dictionary of model labels + :ivar _hf_token: Authentication token Usage: - >>> model = ClassifyModel("username/repo_name") - >>> image = Image.open("path/to/image.jpg") - >>> prediction, score = model.predict(image, "model_name") - >>> print(f"Predicted class: {prediction}, Score: {score}") + >>> classifier = ClassifyModel("org/model-repo") + >>> with Image.open("image.jpg") as img: + ... label, score = classifier.predict(img, "model-name") """ def __init__(self, repo_id: str, fn_preprocess: Optional[ImagePreprocessFunc] = None, hf_token: Optional[str] = None): """ - Initialize the ClassifyModel instance. + Initialize a new ClassifyModel instance. - :param repo_id: The repository ID containing the models. + :param repo_id: Hugging Face repository identifier :type repo_id: str - :param hf_token: The Hugging Face API token, defaults to None. - :type hf_token: Optional[str], optional + :param fn_preprocess: Optional custom preprocessing function + :type fn_preprocess: Optional[ImagePreprocessFunc] + :param hf_token: Authentication token for private repositories + :type hf_token: Optional[str] """ self.repo_id = repo_id self._fn_preprocess = fn_preprocess @@ -138,9 +139,11 @@ def __init__(self, repo_id: str, fn_preprocess: Optional[ImagePreprocessFunc] = def _get_hf_token(self) -> Optional[str]: """ - Get the Hugging Face token from the instance variable or environment variable. + Retrieve the Hugging Face authentication token. + + Returns the instance's ``_hf_token`` when it is non-empty, otherwise the value of the ``HF_TOKEN`` environment variable. - :return: The Hugging Face token. 
+ :return: Authentication token if available :rtype: Optional[str] """ return self._hf_token or os.environ.get('HF_TOKEN') @@ -148,15 +151,15 @@ def _get_hf_token(self) -> Optional[str]: @property def model_names(self) -> List[str]: """ - Get the list of available model names in the repository. + Get available model names from the repository. - This property lazily loads the model names from the Hugging Face repository - and caches them for future use. + This property implements lazy loading and caching of model names. + Thread-safe access to the model list is ensured via locks. - :return: The list of model names available in the repository. + :return: List of available model names :rtype: List[str] - :raises RuntimeError: If there's an error accessing the Hugging Face repository. + :raises RuntimeError: If repository access fails """ with self._global_lock: if self._model_names is None: @@ -174,12 +177,12 @@ def model_names(self) -> List[str]: def _check_model_name(self, model_name: str): """ - Check if the given model name is valid and available in the repository. + Validate model name availability in the repository. - :param model_name: The name of the model to check. + :param model_name: Name of the model to verify :type model_name: str - :raises ValueError: If the model name is not found in the list of available models. + :raises ValueError: If model name is not found in repository """ if model_name not in self.model_names: raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' @@ -187,17 +190,18 @@ def _check_model_name(self, model_name: str): def _open_model(self, model_name: str): """ - Open and cache the specified ONNX model. + Load and cache an ONNX model. - This method downloads the model if it's not already cached and opens it using ONNX runtime. + Implements thread-safe model loading with caching for improved performance. + Downloads model from Hugging Face if not locally available. 
- :param model_name: The name of the model to open. + :param model_name: Name of the model to load :type model_name: str - :return: The opened ONNX model. + :return: Loaded ONNX model :rtype: Any - :raises RuntimeError: If there's an error downloading or opening the model. + :raises RuntimeError: If model loading fails """ with self._model_lock: if model_name not in self._models: @@ -212,17 +216,18 @@ def _open_model(self, model_name: str): def _open_label(self, model_name: str) -> List[str]: """ - Open and cache the labels file for the specified model. + Load and cache model labels from metadata. - This method downloads the meta.json file containing the labels if it's not already cached. + Implements thread-safe loading of model labels with caching. + Downloads label metadata from Hugging Face if not locally available. - :param model_name: The name of the model whose labels to open. + :param model_name: Name of the model whose labels to load :type model_name: str - :return: The list of labels for the specified model. + :return: List of model labels :rtype: List[str] - :raises RuntimeError: If there's an error downloading or parsing the labels file. + :raises RuntimeError: If label loading fails """ with self._model_lock: if model_name not in self._labels: @@ -238,19 +243,23 @@ def _open_label(self, model_name: str) -> List[str]: def _raw_predict(self, image: ImageTyping, model_name: str): """ - Make a raw prediction on the specified image using the specified model. + Generate raw model predictions for an input image. - This method preprocesses the image, runs it through the model, and returns the raw output. + This method handles: + 1. Image loading and preprocessing + 2. Model input shape validation + 3. Custom preprocessing if specified + 4. Model inference - :param image: The input image to classify. + :param image: Input image for prediction :type image: ImageTyping - :param model_name: The name of the model to use for prediction. 
+ :param model_name: Name of model to use :type model_name: str - :return: The raw prediction output from the model. + :return: Raw model output :rtype: np.ndarray - :raises RuntimeError: If the model's input shape is incompatible with the image. + :raises RuntimeError: If model input shape is incompatible """ image = load_image(image, force_background='white', mode='RGB') model = self._open_model(model_name)