diff --git a/docs/source/api_doc/tagging/wd14.rst b/docs/source/api_doc/tagging/wd14.rst
index 0cca46cbe39..79ed07df49f 100644
--- a/docs/source/api_doc/tagging/wd14.rst
+++ b/docs/source/api_doc/tagging/wd14.rst
@@ -12,3 +12,11 @@ get_wd14_tags
 
 .. autofunction:: get_wd14_tags
 
+
+convert_wd14_emb_to_prediction
+----------------------------------------------
+
+.. autofunction:: convert_wd14_emb_to_prediction
+
+
+
diff --git a/docs/source/api_doc/utils/func.rst b/docs/source/api_doc/utils/func.rst
new file mode 100644
index 00000000000..db8ae2b9a65
--- /dev/null
+++ b/docs/source/api_doc/utils/func.rst
@@ -0,0 +1,14 @@
+imgutils.utils.func
+====================================
+
+.. currentmodule:: imgutils.utils.func
+
+.. automodule:: imgutils.utils.func
+
+
+sigmoid
+-------------------------
+
+.. autofunction:: sigmoid
+
+
diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst
index 5cdd2d26ff1..326d2f6b95b 100644
--- a/docs/source/api_doc/utils/index.rst
+++ b/docs/source/api_doc/utils/index.rst
@@ -9,4 +9,5 @@ imgutils.utils
 .. toctree::
     :maxdepth: 3
 
+    func
     onnxruntime
diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py
index ab92f09bc27..dc10bbea0ae 100644
--- a/imgutils/tagging/__init__.py
+++ b/imgutils/tagging/__init__.py
@@ -16,4 +16,4 @@
 from .mldanbooru import get_mldanbooru_tags
 from .order import sort_tags
 from .overlap import drop_overlap_tags
-from .wd14 import get_wd14_tags
+from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py
index 0d00aad1666..a7783635fcf 100644
--- a/imgutils/tagging/wd14.py
+++ b/imgutils/tagging/wd14.py
@@ -1,10 +1,15 @@
 """
 Overview:
-    Tagging utils based on wd14 v2, inspired by
-    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
+    This module provides utilities for image tagging using WD14 taggers.
+    It includes functions for loading models, processing images, and extracting tags.
+
+    The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_
+    project on Hugging Face.
+
 """
+
 from functools import lru_cache
-from typing import List, Tuple, Dict
+from typing import List, Tuple
 
 import numpy as np
 import onnxruntime
@@ -15,8 +20,8 @@
 from .format import remove_underline
 from .overlap import drop_overlap_tags
 
-from ..data import load_image, ImageTyping, has_alpha_channel
-from ..utils import open_onnx_model, vreplace
+from ..data import load_image, ImageTyping
+from ..utils import open_onnx_model, vreplace, sigmoid
 
 SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
 CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -51,6 +56,13 @@
 
 
 def _version_support_check(model_name):
+    """
+    Check if the current onnxruntime version supports the given model.
+
+    :param model_name: The name of the model to check.
+    :type model_name: str
+    :raises EnvironmentError: If the model is not supported by the current onnxruntime version.
+    """
     if model_name.endswith('_v3') and not _IS_V3_SUPPORT:
         raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, '
                                f'please upgrade it to 1.17+ version.\n'
@@ -63,7 +75,7 @@ def _get_wd14_model(model_name):
     """
     Load an ONNX model from the Hugging Face Hub.
 
-    :param model_name: The name of the model.
+    :param model_name: The name of the model to load.
     :type model_name: str
     :return: The loaded ONNX model.
     :rtype: ONNXModel
@@ -75,6 +87,23 @@
     ))
 
 
+@lru_cache()
+def _get_wd14_weights(model_name):
+    """
+    Load the weights for a WD14 model.
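+
+    The ``inv.npz`` archive fetched here is expected to contain the tagger's
+    final linear-layer parameters (``weights`` and ``bias``), so that
+    ``sigmoid(embedding @ weights + bias)`` reproduces the model's raw
+    prediction from a stored, non-normalized embedding.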
+
+    :param model_name: The name of the model.
+    :type model_name: str
+    :return: The loaded weights.
+    :rtype: numpy.ndarray
+    """
+    _version_support_check(model_name)
+    return np.load(hf_hub_download(
+        repo_id='deepghs/wd14_tagger_with_embeddings',
+        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
+    ))
+
+
 @lru_cache()
 def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
     """
@@ -102,10 +131,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str],
 
 def _mcut_threshold(probs) -> float:
     """
-    Maximum Cut Thresholding (MCut)
+    Compute the Maximum Cut Thresholding (MCut) for multi-label classification.
+
+    This method is based on the paper:
     Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
-    for Multi-label Classification. In 11th International Symposium, IDA 2012
-    (pp. 172-183).
+    for Multi-label Classification. In 11th International Symposium, IDA 2012
+    (pp. 172-183).
+
+    :param probs: Array of probabilities.
+    :type probs: numpy.ndarray
+    :return: The computed threshold.
+    :rtype: float
     """
     sorted_probs = probs[probs.argsort()[::-1]]
     difs = sorted_probs[:-1] - sorted_probs[1:]
@@ -115,6 +151,16 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
+    """
+    Prepare an image for tagging by resizing and padding it.
+
+    :param image: The input image.
+    :type image: ImageTyping
+    :param target_size: The target size for the image.
+    :type target_size: int
+    :return: The prepared image as a numpy array.
+    :rtype: numpy.ndarray
+    """
     image = load_image(image, force_background=None, mode=None)
     image_shape = image.size
     max_dim = max(image_shape)
@@ -135,6 +181,76 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
     return np.expand_dims(image_array, axis=0)
 
 
+def _postprocess_embedding(
+        pred, embedding,
+        model_name: str = _DEFAULT_MODEL_NAME,
+        general_threshold: float = 0.35,
+        general_mcut_enabled: bool = False,
+        character_threshold: float = 0.85,
+        character_mcut_enabled: bool = False,
+        no_underline: bool = False,
+        drop_overlap: bool = False,
+        fmt=('rating', 'general', 'character'),
+):
+    """
+    Post-process the embedding and prediction results.
+
+    :param pred: The prediction array.
+    :type pred: numpy.ndarray
+    :param embedding: The embedding array.
+    :type embedding: numpy.ndarray
+    :param model_name: The name of the model used.
+    :type model_name: str
+    :param general_threshold: Threshold for general tags.
+    :type general_threshold: float
+    :param general_mcut_enabled: Whether to use MCut for general tags.
+    :type general_mcut_enabled: bool
+    :param character_threshold: Threshold for character tags.
+    :type character_threshold: float
+    :param character_mcut_enabled: Whether to use MCut for character tags.
+    :type character_mcut_enabled: bool
+    :param no_underline: Whether to remove underscores from tag names.
+    :type no_underline: bool
+    :param drop_overlap: Whether to drop overlapping tags.
+    :type drop_overlap: bool
+    :param fmt: The format of the output.
+    :return: The post-processed results.
+ """ + tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) + labels = list(zip(tag_names, pred.astype(float))) + + rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} + + general_names = [labels[i] for i in general_indexes] + if general_mcut_enabled: + general_probs = np.array([x[1] for x in general_names]) + general_threshold = _mcut_threshold(general_probs) + + general_res = {x: v.item() for x, v in general_names if v > general_threshold} + if drop_overlap: + general_res = drop_overlap_tags(general_res) + + character_names = [labels[i] for i in character_indexes] + if character_mcut_enabled: + character_probs = np.array([x[1] for x in character_names]) + character_threshold = _mcut_threshold(character_probs) + character_threshold = max(0.15, character_threshold) + + character_res = {x: v.item() for x, v in character_names if v > character_threshold} + + return vreplace( + fmt, + { + 'rating': rating, + 'general': general_res, + 'character': character_res, + 'tag': {**general_res, **character_res}, + 'embedding': embedding.astype(np.float32), + 'prediction': pred.astype(np.float32), + } + ) + + def get_wd14_tags( image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, @@ -147,9 +263,10 @@ def get_wd14_tags( fmt=('rating', 'general', 'character'), ): """ - Overview: - Get tags for an image with wd14 taggers. - Similar to `SmilingWolf/wd-v1-4-tags `_ . + Get tags for an image using WD14 taggers. + + This function is similar to the + `SmilingWolf/wd-v1-4-tags `_ project on Hugging Face. :param image: The input image. :type image: ImageTyping @@ -169,19 +286,28 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. - :type fmt: Any - :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. - :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] + :return: Prediction result based on the provided fmt. .. note:: - About ``fmt`` argument, these are the available names: + The fmt argument can include the following keys: + + - ``rating``: a dict containing ratings and their confidences + - ``general``: a dict containing general tags and their confidences + - ``character``: a dict containing character tags and their confidences + - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences + - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization + - ``prediction``: a 1-dim prediction result of image - * ``rating``, a dict containing ratings and their confidences - * ``general``, a dict containing general tags and their confidences - * ``character``, a dict containing character tags and their confidences - * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences - * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization - * ``prediction``, a 1-dim prediction result of image + You can extract embedding of the given image with the follwing code + + >>> from imgutils.tagging import get_wd14_tags + >>> + >>> embedding = get_wd14_tags('skadi.jpg', fmt='embdding') + >>> embedding.shape + (1024, ) + + This embedding is valuable for constructing indices that enable rapid querying of images based on + visual features within large-scale datasets. 
 
     Example:
         Here are some images for example
@@ -189,7 +315,6 @@ def get_wd14_tags(
     .. image:: tagging_demo.plot.py.svg
         :align: center
 
-    >>> import os
     >>> from imgutils.tagging import get_wd14_tags
     >>>
     >>> rating, features, chars = get_wd14_tags('skadi.jpg')
@@ -208,7 +333,7 @@ def get_wd14_tags(
     >>> chars
     {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541}
     """
-    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
+
     model = _get_wd14_model(model_name)
     _, target_size, _, _ = model.get_inputs()[0].shape
     image = _prepare_image_for_tagging(image, target_size)
@@ -218,35 +343,80 @@ def get_wd14_tags(
     label_name = model.get_outputs()[0].name
     emb_name = model.get_outputs()[1].name
     preds, embeddings = model.run([label_name, emb_name], {input_name: image})
-    labels = list(zip(tag_names, preds[0].astype(float)))
 
-    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}
+    return _postprocess_embedding(
+        pred=preds[0],
+        embedding=embeddings[0],
+        model_name=model_name,
+        general_threshold=general_threshold,
+        general_mcut_enabled=general_mcut_enabled,
+        character_threshold=character_threshold,
+        character_mcut_enabled=character_mcut_enabled,
+        no_underline=no_underline,
+        drop_overlap=drop_overlap,
+        fmt=fmt,
+    )
 
-    general_names = [labels[i] for i in general_indexes]
-    if general_mcut_enabled:
-        general_probs = np.array([x[1] for x in general_names])
-        general_threshold = _mcut_threshold(general_probs)
 
-    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
-    if drop_overlap:
-        general_res = drop_overlap_tags(general_res)
+def convert_wd14_emb_to_prediction(
+        emb: np.ndarray,
+        model_name: str = _DEFAULT_MODEL_NAME,
+        general_threshold: float = 0.35,
+        general_mcut_enabled: bool = False,
+        character_threshold: float = 0.85,
+        character_mcut_enabled: bool = False,
+        no_underline: bool = False,
+        drop_overlap: bool = False,
+        fmt=('rating', 'general', 'character'),
+):
+    """
+    Convert a WD14 embedding back into an understandable prediction result.
 
-    character_names = [labels[i] for i in character_indexes]
-    if character_mcut_enabled:
-        character_probs = np.array([x[1] for x in character_names])
-        character_threshold = _mcut_threshold(character_probs)
-        character_threshold = max(0.15, character_threshold)
+    :param emb: The 1-dim extracted embedding.
+    :type emb: numpy.ndarray
+    :param model_name: The name of the model to use.
+    :type model_name: str
+    :param general_threshold: The threshold for general tags.
+    :type general_threshold: float
+    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
+    :type general_mcut_enabled: bool
+    :param character_threshold: The threshold for character tags.
+    :type character_threshold: float
+    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
+    :type character_mcut_enabled: bool
+    :param no_underline: If True, replaces underscores in tag names with spaces.
+    :type no_underline: bool
+    :param drop_overlap: If True, drops overlapping tags.
+    :type drop_overlap: bool
+    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
+    :return: Prediction result based on the provided fmt.
 
-    character_res = {x: v.item() for x, v in character_names if v > character_threshold}
+    .. note::
+        Only embeddings that have not been L2-normalized can be converted back to a meaningful prediction result.
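+
+        This is because the conversion simply re-applies the tagger's stored
+        linear head, i.e. ``pred = sigmoid(emb @ weights + bias)``, so rescaling
+        the embedding (e.g. by L2 normalization) breaks this mapping.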
 
-    return vreplace(
-        fmt,
-        {
-            'rating': rating,
-            'general': general_res,
-            'character': character_res,
-            'tag': {**general_res, **character_res},
-            'embedding': embeddings[0].astype(np.float32),
-            'prediction': preds[0].astype(np.float32),
-        }
+    Example:
+        >>> import os
+        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
+        >>>
+        >>> # extract the feature embedding
+        >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
+        >>>
+        >>> # convert to understandable result
+        >>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
+        >>> # these 3 dicts will be the same as those returned by `get_wd14_tags('skadi.jpg')`
+    """
+    z_weights = _get_wd14_weights(model_name)
+    weights, bias = z_weights['weights'], z_weights['bias']
+    pred = sigmoid(emb @ weights + bias)
+    return _postprocess_embedding(
+        pred=pred,
+        embedding=emb,
+        model_name=model_name,
+        general_threshold=general_threshold,
+        general_mcut_enabled=general_mcut_enabled,
+        character_threshold=character_threshold,
+        character_mcut_enabled=character_mcut_enabled,
+        no_underline=no_underline,
+        drop_overlap=drop_overlap,
+        fmt=fmt,
     )
diff --git a/imgutils/utils/__init__.py b/imgutils/utils/__init__.py
index 43479752294..ce7d2b266f0 100644
--- a/imgutils/utils/__init__.py
+++ b/imgutils/utils/__init__.py
@@ -4,6 +4,7 @@
 """
 from .area import *
 from .format import *
+from .func import *
 from .onnxruntime import *
 from .storage import *
 from .tqdm_ import *
diff --git a/imgutils/utils/func.py b/imgutils/utils/func.py
new file mode 100644
index 00000000000..989ab4c84df
--- /dev/null
+++ b/imgutils/utils/func.py
@@ -0,0 +1,42 @@
+"""
+This module provides mathematical functions related to neural networks.
+
+It includes the sigmoid activation function, which is commonly used in various
+machine learning and deep learning models. The sigmoid function maps any input
+value to a value between 0 and 1, making it useful for binary classification
+problems and as an activation function in neural network layers.
+
+Usage:
+    >>> from imgutils.utils import sigmoid
+    >>> result = sigmoid(input_value)
+"""
+
+import numpy as np
+
+__all__ = ['sigmoid']
+
+
+def sigmoid(x):
+    """
+    Compute the sigmoid function for the input.
+
+    The sigmoid function is defined as:
+    :math:`f\\left(x\\right) = \\frac{1}{1 + e^{-x}}`
+
+    This function applies the sigmoid activation to either a single number
+    or an array of numbers using NumPy for efficient computation.
+
+    :param x: Input value or array of values.
+    :type x: float or numpy.ndarray
+
+    :return: Sigmoid of the input.
+    :rtype: float or numpy.ndarray
+
+    :example:
+        >>> import numpy as np
+        >>> sigmoid(0)
+        0.5
+        >>> sigmoid(np.array([-1, 0, 1]))
+        array([0.26894142, 0.5       , 0.73105858])
+    """
+    return 1 / (1 + np.exp(-x))
diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py
index cbaf33efa11..cc417cdf3aa 100644
--- a/test/tagging/test_wd14.py
+++ b/test/tagging/test_wd14.py
@@ -1,6 +1,6 @@
 import pytest
 
-from imgutils.tagging import get_wd14_tags
+from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
 from imgutils.tagging.wd14 import _get_wd14_model
 from test.testings import get_testfile
 
@@ -159,3 +159,17 @@ def test_wd14_rgba(self):
             'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728
         }, abs=2e-2)
         assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2)
+
+    @pytest.mark.parametrize(['file'], [
+        ('nude_girl.png',),
+        ('nian.png',),
+    ])
+    def test_convert_wd14_emb_to_prediction(self, file):
+        file = get_testfile(file)
+        (expected_rating, expected_general, expected_character), embedding = \
+            get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))
+
+        rating, general, character = convert_wd14_emb_to_prediction(embedding)
+        assert rating == pytest.approx(expected_rating, abs=2e-3)
+        assert general == pytest.approx(expected_general, abs=2e-3)
+        assert character == pytest.approx(expected_character, abs=2e-3)
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/test/utils/test_func.py b/test/utils/test_func.py
new file mode 100644
index 00000000000..f3b9df06b03
--- /dev/null
+++ b/test/utils/test_func.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+
+from imgutils.utils import sigmoid
+
+
+@pytest.fixture
+def sample_input():
+    return np.array([-1, 0, 1])
+
+
+@pytest.fixture
+def expected_output():
+    return np.array([0.26894142, 0.5, 0.73105858])
+
+
+@pytest.mark.unittest
+class TestUtilsFuncSigmoid:
+    def test_sigmoid_scalar(self):
+        assert np.isclose(sigmoid(0), 0.5)
+
+    def test_sigmoid_array(self, sample_input, expected_output):
+        result = sigmoid(sample_input)
+        np.testing.assert_array_almost_equal(result, expected_output)
+
+    def test_sigmoid_large_positive(self):
+        assert np.isclose(sigmoid(100), 1.0)
+
+    def test_sigmoid_large_negative(self):
+        assert np.isclose(sigmoid(-100), 0.0)
+
+    def test_sigmoid_zero(self):
+        assert sigmoid(0) == 0.5
+
+    def test_sigmoid_type(self):
+        assert isinstance(sigmoid(1), float)
+        assert isinstance(sigmoid(np.array([1])), np.ndarray)
+
+    def test_sigmoid_shape(self):
+        input_array = np.array([[1, 2], [3, 4]])
+        result = sigmoid(input_array)
+        assert result.shape == input_array.shape
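+
+    def test_sigmoid_complement(self):
+        # Suggested extra property check (not part of the original patch):
+        # the identity sigmoid(-x) == 1 - sigmoid(x) should hold elementwise.
+        x = np.linspace(-5.0, 5.0, 11)
+        np.testing.assert_array_almost_equal(sigmoid(-x), 1 - sigmoid(x))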