
Commit

Merge pull request #126 from deepghs/dev/denormalize
dev(narugo): add de-normalizers for the embeddings of the wd14 taggers
narugo1992 authored Nov 17, 2024
2 parents 510b6cc + 256758a commit d43dcae
Showing 4 changed files with 211 additions and 24 deletions.
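As a quick orientation (not part of the diff itself), here is a minimal usage sketch of the API this commit adds, assembled from the docstring examples below; 'image.png' is a placeholder path:

import numpy as np
from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb

# Extract a WD14 embedding and L2-normalize it (as one typically would before indexing it).
embedding = get_wd14_tags('image.png', fmt='embedding')
normed = embedding / np.linalg.norm(embedding)

# Recover a tagger-compatible embedding from the normalized one, then read tags off it ...
restored = denormalize_wd14_emb(normed)
rating, general, character = convert_wd14_emb_to_prediction(restored)

# ... or let the converter denormalize internally in one call.
rating, general, character = convert_wd14_emb_to_prediction(normed, denormalize=True)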
6 changes: 6 additions & 0 deletions docs/source/api_doc/tagging/wd14.rst
@@ -20,3 +20,9 @@ convert_wd14_emb_to_prediction



denormalize_wd14_emb
----------------------------------------------

.. autofunction:: denormalize_wd14_emb


2 changes: 1 addition & 1 deletion imgutils/tagging/__init__.py
@@ -16,4 +16,4 @@
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
169 changes: 147 additions & 22 deletions imgutils/tagging/wd14.py
Expand Up @@ -214,6 +214,9 @@ def _postprocess_embedding(
:param fmt: The format of the output.
:return: The post-processed results.
"""
assert len(pred.shape) == len(embedding.shape) == 1, \
f'Both pred and embedding should be 1-dim, ' \
f'but got pred: {pred.shape!r}, embedding: {embedding.shape!r}.'
tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
labels = list(zip(tag_names, pred.astype(float)))

@@ -356,6 +359,9 @@ def get_wd14_tags(
)


_DEFAULT_DENORMALIZER_NAME = 'mnum2_all'


def convert_wd14_emb_to_prediction(
emb: np.ndarray,
model_name: str = _DEFAULT_MODEL_NAME,
@@ -366,55 +372,174 @@ def convert_wd14_emb_to_prediction(
no_underline: bool = False,
drop_overlap: bool = False,
fmt=('rating', 'general', 'character'),
denormalize: bool = False,
denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
"""
Convert a WD14 embedding into an understandable prediction result. This function can process both
single embeddings (1-dimensional arrays) and batches of embeddings (2-dimensional arrays).

:param emb: The extracted embedding(s). Can be either a 1-dim array for a single image or
a 2-dim array for batch processing.
:type emb: numpy.ndarray
:param model_name: Name of the WD14 model to use for prediction.
:type model_name: str
:param general_threshold: Confidence threshold for general tags (0.0 to 1.0).
:type general_threshold: float
:param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality.
:type general_mcut_enabled: bool
:param character_threshold: Confidence threshold for character tags (0.0 to 1.0).
:type character_threshold: float
:param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality.
:type character_mcut_enabled: bool
:param no_underline: Replace underscores with spaces in tag names for better readability.
:type no_underline: bool
:param drop_overlap: Remove overlapping tags to reduce redundancy.
:type drop_overlap: bool
:param fmt: Return format structure for predictions, default is ``('rating', 'general', 'character')``.
:type fmt: tuple
:param denormalize: Whether to denormalize the embedding before prediction.
:type denormalize: bool
:param denormalizer_name: Name of the denormalizer to use if denormalization is enabled.
:type denormalizer_name: str
:return: For single embeddings: prediction result based on ``fmt``. For batches: a list of prediction results.

.. note::
Only embeddings that have not been normalized can be converted directly into an understandable
prediction result. If normalized embeddings are provided, set ``denormalize=True`` to convert them back.

For batch processing (2-dim input), a list is returned where each element corresponds
to one embedding's predictions, in the same format as the single-embedding output.

Example:
>>> import os
>>> import numpy as np
>>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
>>>
>>> # extract the feature embedding
>>> # extract the feature embedding, shape: (W, )
>>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
>>>
>>> # convert to understandable result
>>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
>>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')`
>>>
>>> # Batch processing, shape: (B, W)
>>> embeddings = np.stack([
... get_wd14_tags('img1.jpg', fmt='embedding'),
... get_wd14_tags('img2.jpg', fmt='embedding'),
... ])
>>> # results will be a list of (rating, general, character) tuples
>>> results = convert_wd14_emb_to_prediction(embeddings)
"""
if denormalize:
emb = denormalize_wd14_emb(
emb=emb,
model_name=model_name,
denormalizer_name=denormalizer_name,
)

z_weights = _get_wd14_weights(model_name)
weights, bias = z_weights['weights'], z_weights['bias']
pred = sigmoid(emb @ weights + bias)
if len(emb.shape) == 1:
return _postprocess_embedding(
pred=pred,
embedding=emb,
model_name=model_name,
general_threshold=general_threshold,
general_mcut_enabled=general_mcut_enabled,
character_threshold=character_threshold,
character_mcut_enabled=character_mcut_enabled,
no_underline=no_underline,
drop_overlap=drop_overlap,
fmt=fmt,
)
else:
return [
_postprocess_embedding(
pred=pred_item,
embedding=emb_item,
model_name=model_name,
general_threshold=general_threshold,
general_mcut_enabled=general_mcut_enabled,
character_threshold=character_threshold,
character_mcut_enabled=character_mcut_enabled,
no_underline=no_underline,
drop_overlap=drop_overlap,
fmt=fmt,
)
for pred_item, emb_item in zip(pred, emb)
]
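As an aside (not part of the diff): the conversion above is just a linear head plus sigmoid, so it only reproduces the tagger's scores when given the raw, unnormalized embedding, because rescaling the input shifts every logit. The following self-contained NumPy sketch, using made-up dimensions and random stand-in weights, illustrates why ``denormalize=True`` is needed once only L2-normalized embeddings are available:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
emb = rng.normal(size=1024)              # stand-in for a raw WD14 embedding
weights = rng.normal(size=(1024, 8))     # stand-in for the tagger's linear head
bias = rng.normal(size=8)

raw_scores = sigmoid(emb @ weights + bias)
normed_scores = sigmoid((emb / np.linalg.norm(emb)) @ weights + bias)

# The two score vectors differ, so a normalized embedding must be mapped back
# (denormalized) before applying the head and the tag thresholds.
print(np.abs(raw_scores - normed_scores).max())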


@ts_lru_cache()
def _open_denormalize_model(
model_name: str = _DEFAULT_MODEL_NAME,
denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
"""
Open a denormalization model for WD14 embeddings.
:param model_name: Name of the model.
:type model_name: str
:param denormalizer_name: Name of the denormalizer.
:type denormalizer_name: str
:return: The loaded ONNX model.
:rtype: ONNXModel
"""
return open_onnx_model(hf_hub_download(
repo_id='deepghs/embedding_aligner',
repo_type='model',
filename=f'{model_name}_{denormalizer_name}/model.onnx',
))
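A hedged aside on where the aligners live: judging from the ``hf_hub_download`` call above, each denormalizer is stored as ``{model_name}_{denormalizer_name}/model.onnx`` in the ``deepghs/embedding_aligner`` repository, so the available combinations can be discovered by listing that repository (the naming pattern is inferred from the code, not documented here):

from huggingface_hub import list_repo_files

# Each '{model_name}_{denormalizer_name}' directory holds one ONNX aligner.
files = list_repo_files('deepghs/embedding_aligner', repo_type='model')
available = sorted({f.rsplit('/', 1)[0] for f in files if f.endswith('/model.onnx')})
print(available)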


def denormalize_wd14_emb(
emb: np.ndarray,
model_name: str = _DEFAULT_MODEL_NAME,
denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
) -> np.ndarray:
"""
Denormalize WD14 embeddings.
:param emb: The embedding to denormalize.
:type emb: numpy.ndarray
:param model_name: Name of the model.
:type model_name: str
:param denormalizer_name: Name of the denormalizer.
:type denormalizer_name: str
:return: The denormalized embedding.
:rtype: numpy.ndarray
Examples:
>>> import numpy as np
>>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
...
>>> embedding, (r, g, c) = get_wd14_tags(
... 'image.png',
... fmt=('embedding', ('rating', 'general', 'character')),
... )
...
>>> # normalize embedding
>>> embedding = embedding / np.linalg.norm(embedding)
...
>>> # denormalize this embedding
>>> output = denormalize_wd14_emb(embedding)
...
>>> # should be similar to r, g, c, approx 1e-3 error
>>> rating, general, character = convert_wd14_emb_to_prediction(output)
"""
model = _open_denormalize_model(
model_name=model_name,
denormalizer_name=denormalizer_name,
)
emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
if len(emb.shape) == 1:
output, = model.run(output_names=['embedding'], input_feed={'input': emb[None, ...]})
return output[0]
else:
embedding_width = model.get_outputs()[0].shape[-1]
origin_shape = emb.shape
emb = emb.reshape(-1, embedding_width)
output, = model.run(output_names=['embedding'], input_feed={'input': emb})
return output.reshape(*origin_shape)
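A short sketch (not from the diff) of the batch path implemented above: 2-dim input is reshaped to ``(N, W)``, pushed through the ONNX aligner in one run, and reshaped back, so batched and per-item calls should agree up to numerical noise. Here 'a.png' and 'b.png' are placeholder files and the tolerance is an assumption:

import numpy as np
from imgutils.tagging import get_wd14_tags, denormalize_wd14_emb

files = ['a.png', 'b.png']  # placeholder image paths
embs = np.stack([get_wd14_tags(f, fmt='embedding') for f in files])
embs = embs / np.linalg.norm(embs, axis=-1, keepdims=True)

batched = denormalize_wd14_emb(embs)    # shape (2, W), same shape as the input
single = denormalize_wd14_emb(embs[0])  # shape (W,)
assert np.allclose(batched[0], single, atol=1e-5)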
58 changes: 57 additions & 1 deletion test/tagging/test_wd14.py
@@ -1,7 +1,8 @@
import numpy as np
import pytest

from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
from imgutils.tagging.wd14 import _get_wd14_model
from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model
from test.testings import get_testfile


@@ -11,6 +12,7 @@ def _release_model_after_run():
yield
finally:
_get_wd14_model.cache_clear()
_open_denormalize_model.cache_clear()


@pytest.mark.unittest
@@ -173,3 +175,57 @@ def test_convert_wd14_emb_to_prediction(self, file):
assert rating == pytest.approx(expected_rating, abs=2e-3)
assert general == pytest.approx(expected_general, abs=2e-3)
assert character == pytest.approx(expected_character, abs=2e-3)

@pytest.mark.parametrize(['file'], [
('nude_girl.png',),
])
def test_convert_wd14_emb_to_prediction_denormalize(self, file):
file = get_testfile(file)
(expected_rating, expected_general, expected_character), embedding = \
get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))

embedding = embedding / np.linalg.norm(embedding)
rating, general, character = convert_wd14_emb_to_prediction(embedding, denormalize=True)
assert rating == pytest.approx(expected_rating, abs=1e-2)
assert general == pytest.approx(expected_general, abs=1e-2)
assert character == pytest.approx(expected_character, abs=1e-2)

@pytest.mark.parametrize(['file'], [
('nude_girl.png',),
# ('nian.png',), # some low scores not match
])
def test_denormalize_wd14_emb(self, file):
file = get_testfile(file)
(expected_rating, expected_general, expected_character), embedding = \
get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))

embedding = embedding / np.linalg.norm(embedding)
output = denormalize_wd14_emb(embedding)
rating, general, character = convert_wd14_emb_to_prediction(output)
assert rating == pytest.approx(expected_rating, abs=1e-2)
assert general == pytest.approx(expected_general, abs=1e-2)
assert character == pytest.approx(expected_character, abs=1e-2)

@pytest.mark.parametrize(['files'], [
(['nude_girl.png'],),
(['nude_girl.png', 'nude_girl.png'],),
# ('nian.png',), # some low scores not match
])
def test_denormalize_wd14_emb_multiple(self, files):
files = [get_testfile(file) for file in files]
expected = []
embeddings = []
for file in files:
(expected_rating, expected_general, expected_character), embedding = \
get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))
expected.append((expected_rating, expected_general, expected_character))
embeddings.append(embedding / np.linalg.norm(embedding))

embeddings = np.stack(embeddings)
outputs = denormalize_wd14_emb(embeddings)
actual = convert_wd14_emb_to_prediction(outputs)
for (expected_rating, expected_general, expected_character), \
(rating, general, character) in zip(expected, actual):
assert rating == pytest.approx(expected_rating, abs=1e-2)
assert general == pytest.approx(expected_general, abs=1e-2)
assert character == pytest.approx(expected_character, abs=1e-2)
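The new tests can presumably be run in isolation with the repository's usual pytest setup, for example via ``pytest.main`` as sketched below (exact markers or flags required by the project's configuration may differ):

import pytest

# Run only the denormalization-related tests added in this commit.
pytest.main(['test/tagging/test_wd14.py', '-k', 'denormalize'])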
