Skip to content

Commit

Permalink
Merge pull request #127 from deepghs/dev/preprocess
Browse files Browse the repository at this point in the history
dev(narugo): add preprocessor into the ClassifyModel
  • Loading branch information
narugo1992 authored Nov 25, 2024
2 parents d43dcae + cb17dac commit 9ba8f23
Showing 1 changed file with 87 additions and 70 deletions.
157 changes: 87 additions & 70 deletions imgutils/generic/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
This module provides utilities and classes for working with classification models,
particularly those stored in Hugging Face repositories. It includes functions for
image encoding, model loading, and prediction, as well as a main `ClassifyModel` class
image encoding, model loading, and prediction, as well as a main ClassifyModel class
that manages the interaction with classification models.
Key components:
- Image encoding and preprocessing
- ClassifyModel: A class for managing and using classification models
- Utility functions for making predictions with classification models
The module is designed to work with ONNX models and supports various image input formats.
It also handles token-based authentication for accessing private Hugging Face repositories.
"""

import json
import os
from threading import Lock
from typing import Tuple, Optional, List, Dict
from typing import Tuple, Optional, List, Dict, Callable

import numpy as np
from PIL import Image
Expand All @@ -45,9 +39,12 @@

def _check_gradio_env():
"""
Check if the Gradio library is installed and available.
Verify that Gradio library is properly installed and available.
:raises EnvironmentError: If Gradio is not installed.
This function checks if the Gradio package is accessible for creating
web-based demos. If Gradio is not found, it provides instructions for installation.
:raises EnvironmentError: If Gradio package is not installed in the environment.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
Expand All @@ -57,25 +54,30 @@ def _check_gradio_env():
def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Encode an image into a numpy array for model input.
Encode an image into a numpy array suitable for model input.
This function resizes the input image, converts it to RGB format, and optionally
normalizes the pixel values.
This function performs several preprocessing steps on the input image:
1. Resizes the image to the specified dimensions
2. Converts to RGB format
3. Applies normalization if parameters are provided
4. Returns the image in CHW (Channel, Height, Width) format
:param image: The input image to be encoded.
:param image: Input PIL Image to be encoded
:type image: Image.Image
:param size: The target size (width, height) to resize the image to, defaults to (384, 384).
:type size: Tuple[int, int], optional
:param normalize: The mean and standard deviation for normalization, defaults to (0.5, 0.5).
If None, no normalization is applied.
:type normalize: Optional[Tuple[float, float]], optional
:param size: Target dimensions (width, height) for resizing, defaults to (384, 384)
:type size: Tuple[int, int]
:param normalize: Normalization parameters (mean, std), defaults to (0.5, 0.5)
:type normalize: Optional[Tuple[float, float]]
:return: The encoded image as a numpy array in CHW format.
:return: Encoded and preprocessed image as numpy array
:rtype: np.ndarray
:raises TypeError: If the input image is not a PIL Image object.
:raises TypeError: If input is not a PIL Image
Example:
>>> img = Image.open('example.jpg')
>>> encoded = _img_encode(img, size=(224, 224))
"""
# noinspection PyUnresolvedReferences
image = image.resize(size, Image.BILINEAR)
data = rgb_encode(image, order_='CHW')

Expand All @@ -88,42 +90,46 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
return data.astype(np.float32)


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]


class ClassifyModel:
"""
A class for managing and using classification models.
This class provides methods for loading classification models from a Hugging Face
repository, making predictions, and managing model resources. It supports multiple
models within a single repository and handles token-based authentication.
A comprehensive manager for classification models from Hugging Face repositories.
:param repo_id: The Hugging Face repository ID containing the classification models.
:param repo_id: Hugging Face repository identifier
:type repo_id: str
:param hf_token: The Hugging Face API token for accessing private repositories, defaults to None.
:type hf_token: Optional[str], optional
:param fn_preprocess: Optional custom preprocessing function
:type fn_preprocess: Optional[ImagePreprocessFunc]
:param hf_token: Hugging Face authentication token
:type hf_token: Optional[str]
:ivar repo_id: The Hugging Face repository ID.
:ivar _model_names: Cached list of available model names in the repository.
:ivar _models: Dictionary of loaded ONNX models.
:ivar _labels: Dictionary of labels for each model.
:ivar _hf_token: The Hugging Face API token.
:ivar repo_id: Repository identifier
:ivar _model_names: Cached list of available models
:ivar _models: Dictionary of loaded ONNX models
:ivar _labels: Dictionary of model labels
:ivar _hf_token: Authentication token
Usage:
>>> model = ClassifyModel("username/repo_name")
>>> image = Image.open("path/to/image.jpg")
>>> prediction, score = model.predict(image, "model_name")
>>> print(f"Predicted class: {prediction}, Score: {score}")
>>> classifier = ClassifyModel("org/model-repo")
>>> with Image.open("image.jpg") as img:
... label = classifier.predict(img, "model-name")
"""

def __init__(self, repo_id: str, hf_token: Optional[str] = None):
def __init__(self, repo_id: str, fn_preprocess: Optional[ImagePreprocessFunc] = None,
hf_token: Optional[str] = None):
"""
Initialize the ClassifyModel instance.
Initialize a new ClassifyModel instance.
:param repo_id: The repository ID containing the models.
:param repo_id: Hugging Face repository identifier
:type repo_id: str
:param hf_token: The Hugging Face API token, defaults to None.
:type hf_token: Optional[str], optional
:param fn_preprocess: Optional custom preprocessing function
:type fn_preprocess: Optional[ImagePreprocessFunc]
:param hf_token: Authentication token for private repositories
:type hf_token: Optional[str]
"""
self.repo_id = repo_id
self._fn_preprocess = fn_preprocess
self._model_names = None
self._models = {}
self._labels = {}
Expand All @@ -133,25 +139,27 @@ def __init__(self, repo_id: str, hf_token: Optional[str] = None):

def _get_hf_token(self) -> Optional[str]:
"""
Get the Hugging Face token from the instance variable or environment variable.
Retrieve the Hugging Face authentication token.
Checks both instance variable and environment for token presence.
:return: The Hugging Face token.
:return: Authentication token if available
:rtype: Optional[str]
"""
return self._hf_token or os.environ.get('HF_TOKEN')

@property
def model_names(self) -> List[str]:
"""
Get the list of available model names in the repository.
Get available model names from the repository.
This property lazily loads the model names from the Hugging Face repository
and caches them for future use.
This property implements lazy loading and caching of model names.
Thread-safe access to the model list is ensured via locks.
:return: The list of model names available in the repository.
:return: List of available model names
:rtype: List[str]
:raises RuntimeError: If there's an error accessing the Hugging Face repository.
:raises RuntimeError: If repository access fails
"""
with self._global_lock:
if self._model_names is None:
Expand All @@ -169,30 +177,31 @@ def model_names(self) -> List[str]:

def _check_model_name(self, model_name: str):
"""
Check if the given model name is valid and available in the repository.
Validate model name availability in the repository.
:param model_name: The name of the model to check.
:param model_name: Name of the model to verify
:type model_name: str
:raises ValueError: If the model name is not found in the list of available models.
:raises ValueError: If model name is not found in repository
"""
if model_name not in self.model_names:
raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
f'models {self.model_names!r} are available.')

def _open_model(self, model_name: str):
"""
Open and cache the specified ONNX model.
Load and cache an ONNX model.
This method downloads the model if it's not already cached and opens it using ONNX runtime.
Implements thread-safe model loading with caching for improved performance.
Downloads model from Hugging Face if not locally available.
:param model_name: The name of the model to open.
:param model_name: Name of the model to load
:type model_name: str
:return: The opened ONNX model.
:return: Loaded ONNX model
:rtype: Any
:raises RuntimeError: If there's an error downloading or opening the model.
:raises RuntimeError: If model loading fails
"""
with self._model_lock:
if model_name not in self._models:
Expand All @@ -207,17 +216,18 @@ def _open_model(self, model_name: str):

def _open_label(self, model_name: str) -> List[str]:
"""
Open and cache the labels file for the specified model.
Load and cache model labels from metadata.
This method downloads the meta.json file containing the labels if it's not already cached.
Implements thread-safe loading of model labels with caching.
Downloads label metadata from Hugging Face if not locally available.
:param model_name: The name of the model whose labels to open.
:param model_name: Name of the model whose labels to load
:type model_name: str
:return: The list of labels for the specified model.
:return: List of model labels
:rtype: List[str]
:raises RuntimeError: If there's an error downloading or parsing the labels file.
:raises RuntimeError: If label loading fails
"""
with self._model_lock:
if model_name not in self._labels:
Expand All @@ -233,19 +243,23 @@ def _open_label(self, model_name: str) -> List[str]:

def _raw_predict(self, image: ImageTyping, model_name: str):
"""
Make a raw prediction on the specified image using the specified model.
Generate raw model predictions for an input image.
This method preprocesses the image, runs it through the model, and returns the raw output.
This method handles:
1. Image loading and preprocessing
2. Model input shape validation
3. Custom preprocessing if specified
4. Model inference
:param image: The input image to classify.
:param image: Input image for prediction
:type image: ImageTyping
:param model_name: The name of the model to use for prediction.
:param model_name: Name of model to use
:type model_name: str
:return: The raw prediction output from the model.
:return: Raw model output
:rtype: np.ndarray
:raises RuntimeError: If the model's input shape is incompatible with the image.
:raises RuntimeError: If model input shape is incompatible
"""
image = load_image(image, force_background='white', mode='RGB')
model = self._open_model(model_name)
Expand All @@ -254,6 +268,9 @@ def _raw_predict(self, image: ImageTyping, model_name: str):
raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, '
f'channels not 3.') # pragma: no cover

if self._fn_preprocess:
image = self._fn_preprocess(image)

if isinstance(height, int) and isinstance(width, int):
input_ = _img_encode(image, size=(width, height))[None, ...]
else:
Expand Down

0 comments on commit 9ba8f23

Please sign in to comment.