diff --git a/docs/source/api_doc/tagging/wd14.rst b/docs/source/api_doc/tagging/wd14.rst index 79ed07df49..9a33acf160 100644 --- a/docs/source/api_doc/tagging/wd14.rst +++ b/docs/source/api_doc/tagging/wd14.rst @@ -20,3 +20,9 @@ convert_wd14_emb_to_prediction +denormalize_wd14_emb +---------------------------------------------- + +.. autofunction:: denormalize_wd14_emb + + diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py index dc10bbea0a..81344acfd8 100644 --- a/imgutils/tagging/__init__.py +++ b/imgutils/tagging/__init__.py @@ -16,4 +16,4 @@ from .mldanbooru import get_mldanbooru_tags from .order import sort_tags from .overlap import drop_overlap_tags -from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction +from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index 12fcd05654..33c9ffa108 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -214,6 +214,9 @@ def _postprocess_embedding( :param fmt: The format of the output. :return: The post-processed results. """ + assert len(pred.shape) == len(embedding.shape) == 1, \ + f'Both pred and embeddings shapes should be 1-dim, ' \ + f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.' tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) @@ -356,6 +359,9 @@ def get_wd14_tags( ) +_DEFAULT_DENORMALIZER_NAME = 'mnum2_all' + + def convert_wd14_emb_to_prediction( emb: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, @@ -366,55 +372,174 @@ def convert_wd14_emb_to_prediction( no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), + denormalize: bool = False, + denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, ): """ - Convert WD14 embedding to understandable prediction result. + Convert WD14 embedding to understandable prediction result. This function can process both + single embeddings (1-dimensional array) and batches of embeddings (2-dimensional array). - :param emb: The 1-dim extracted embedding. + :param emb: The extracted embedding(s). Can be either a 1-dim array for single image or + 2-dim array for batch processing :type emb: numpy.ndarray - :param model_name: The name of the model to use. + :param model_name: Name of the WD14 model to use for prediction :type model_name: str - :param general_threshold: The threshold for general tags. + :param general_threshold: Confidence threshold for general tags (0.0 to 1.0) :type general_threshold: float - :param general_mcut_enabled: If True, applies MCut thresholding to general tags. + :param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality :type general_mcut_enabled: bool - :param character_threshold: The threshold for character tags. + :param character_threshold: Confidence threshold for character tags (0.0 to 1.0) :type character_threshold: float - :param character_mcut_enabled: If True, applies MCut thresholding to character tags. + :param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality :type character_mcut_enabled: bool - :param no_underline: If True, replaces underscores in tag names with spaces. + :param no_underline: Replace underscores with spaces in tag names for better readability :type no_underline: bool - :param drop_overlap: If True, drops overlapping tags. + :param drop_overlap: Remove overlapping tags to reduce redundancy :type drop_overlap: bool - :param fmt: Return format, default is ``('rating', 'general', 'character')``. - :return: Prediction result based on the provided fmt. + :param fmt: Specify return format structure for predictions, default is ``('rating', 'general', 'character')``. + :type fmt: tuple + :param denormalize: Whether to denormalize the embedding before prediction + :type denormalize: bool + :param denormalizer_name: Name of the denormalizer to use if denormalization is enabled + :type denormalizer_name: str + :return: For single embeddings: prediction result based on fmt. For batches: list of prediction results. .. note:: Only the embeddings not get normalized can be converted to understandable prediction result. + If normalized embeddings are provided, set ``denormalize=True`` to convert them back. + + For batch processing (2-dim input), returns a list where each element corresponds + to one embedding's predictions in the same format as single embedding output. Example: >>> import os + >>> import numpy as np >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction >>> - >>> # extract the feature embedding + >>> # extract the feature embedding, shape: (W, ) >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> >>> # convert to understandable result >>> rating, general, character = convert_wd14_emb_to_prediction(embedding) >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')` + >>> + >>> # Batch processing, shape: (B, W) + >>> embeddings = np.stack([ + ... get_wd14_tags('img1.jpg', fmt='embedding'), + ... get_wd14_tags('img2.jpg', fmt='embedding'), + ... ]) + >>> # results will be a list of (rating, general, character) tuples + >>> results = convert_wd14_emb_to_prediction(embeddings) """ + if denormalize: + emb = denormalize_wd14_emb( + emb=emb, + model_name=model_name, + denormalizer_name=denormalizer_name, + ) + z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) - return _postprocess_embedding( - pred=pred, - embedding=emb, + if len(emb.shape) == 1: + return _postprocess_embedding( + pred=pred, + embedding=emb, + model_name=model_name, + general_threshold=general_threshold, + general_mcut_enabled=general_mcut_enabled, + character_threshold=character_threshold, + character_mcut_enabled=character_mcut_enabled, + no_underline=no_underline, + drop_overlap=drop_overlap, + fmt=fmt, + ) + else: + return [ + _postprocess_embedding( + pred=pred_item, + embedding=emb_item, + model_name=model_name, + general_threshold=general_threshold, + general_mcut_enabled=general_mcut_enabled, + character_threshold=character_threshold, + character_mcut_enabled=character_mcut_enabled, + no_underline=no_underline, + drop_overlap=drop_overlap, + fmt=fmt, + ) + for pred_item, emb_item in zip(pred, emb) + ] + + +@ts_lru_cache() +def _open_denormalize_model( + model_name: str = _DEFAULT_MODEL_NAME, + denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, +): + """ + Open a denormalization model for WD14 embeddings. + + :param model_name: Name of the model. + :type model_name: str + :param denormalizer_name: Name of the denormalizer. + :type denormalizer_name: str + :return: The loaded ONNX model. + :rtype: ONNXModel + """ + return open_onnx_model(hf_hub_download( + repo_id='deepghs/embedding_aligner', + repo_type='model', + filename=f'{model_name}_{denormalizer_name}/model.onnx', + )) + + +def denormalize_wd14_emb( + emb: np.ndarray, + model_name: str = _DEFAULT_MODEL_NAME, + denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, +) -> np.ndarray: + """ + Denormalize WD14 embeddings. + + :param emb: The embedding to denormalize. + :type emb: numpy.ndarray + :param model_name: Name of the model. + :type model_name: str + :param denormalizer_name: Name of the denormalizer. + :type denormalizer_name: str + :return: The denormalized embedding. + :rtype: numpy.ndarray + + Examples: + >>> import numpy as np + >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb + ... + >>> embedding, (r, g, c) = get_wd14_tags( + ... 'image.png', + ... fmt=('embedding', ('rating', 'general', 'character')), + ... ) + ... + >>> # normalize embedding + >>> embedding = embedding / np.linalg.norm(embedding) + ... + >>> # denormalize this embedding + >>> output = denormalize_wd14_emb(embedding) + ... + >>> # should be similar to r, g, c, approx 1e-3 error + >>> rating, general, character = convert_wd14_emb_to_prediction(output) + """ + model = _open_denormalize_model( model_name=model_name, - general_threshold=general_threshold, - general_mcut_enabled=general_mcut_enabled, - character_threshold=character_threshold, - character_mcut_enabled=character_mcut_enabled, - no_underline=no_underline, - drop_overlap=drop_overlap, - fmt=fmt, + denormalizer_name=denormalizer_name, ) + emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True) + if len(emb.shape) == 1: + output, = model.run(output_names=['embedding'], input_feed={'input': emb[None, ...]}) + return output[0] + else: + embedding_width = model.get_outputs()[0].shape[-1] + origin_shape = emb.shape + emb = emb.reshape(-1, embedding_width) + output, = model.run(output_names=['embedding'], input_feed={'input': emb}) + return output.reshape(*origin_shape) diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index cc417cdf3a..289d20efad 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -1,7 +1,8 @@ +import numpy as np import pytest from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction -from imgutils.tagging.wd14 import _get_wd14_model +from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model from test.testings import get_testfile @@ -11,6 +12,7 @@ def _release_model_after_run(): yield finally: _get_wd14_model.cache_clear() + _open_denormalize_model.cache_clear() @pytest.mark.unittest @@ -173,3 +175,57 @@ def test_convert_wd14_emb_to_prediction(self, file): assert rating == pytest.approx(expected_rating, abs=2e-3) assert general == pytest.approx(expected_general, abs=2e-3) assert character == pytest.approx(expected_character, abs=2e-3) + + @pytest.mark.parametrize(['file'], [ + ('nude_girl.png',), + ]) + def test_convert_wd14_emb_to_prediction_denormalize(self, file): + file = get_testfile(file) + (expected_rating, expected_general, expected_character), embedding = \ + get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding')) + + embedding = embedding / np.linalg.norm(embedding) + rating, general, character = convert_wd14_emb_to_prediction(embedding, denormalize=True) + assert rating == pytest.approx(expected_rating, abs=1e-2) + assert general == pytest.approx(expected_general, abs=1e-2) + assert character == pytest.approx(expected_character, abs=1e-2) + + @pytest.mark.parametrize(['file'], [ + ('nude_girl.png',), + # ('nian.png',), # some low scores not match + ]) + def test_denormalize_wd14_emb(self, file): + file = get_testfile(file) + (expected_rating, expected_general, expected_character), embedding = \ + get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding')) + + embedding = embedding / np.linalg.norm(embedding) + output = denormalize_wd14_emb(embedding) + rating, general, character = convert_wd14_emb_to_prediction(output) + assert rating == pytest.approx(expected_rating, abs=1e-2) + assert general == pytest.approx(expected_general, abs=1e-2) + assert character == pytest.approx(expected_character, abs=1e-2) + + @pytest.mark.parametrize(['files'], [ + (['nude_girl.png'],), + (['nude_girl.png', 'nude_girl.png'],), + # ('nian.png',), # some low scores not match + ]) + def test_denormalize_wd14_emb_multiple(self, files): + files = [get_testfile(file) for file in files] + expected = [] + embeddings = [] + for file in files: + (expected_rating, expected_general, expected_character), embedding = \ + get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding')) + expected.append((expected_rating, expected_general, expected_character)) + embeddings.append(embedding / np.linalg.norm(embedding)) + + embeddings = np.stack(embeddings) + outputs = denormalize_wd14_emb(embeddings) + actual = convert_wd14_emb_to_prediction(outputs) + for (expected_rating, expected_general, expected_character), \ + (rating, general, character) in zip(expected, actual): + assert rating == pytest.approx(expected_rating, abs=1e-2) + assert general == pytest.approx(expected_general, abs=1e-2) + assert character == pytest.approx(expected_character, abs=1e-2)