From 2b862b6b5476048e107a83ca92857dcc365617fd Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Sun, 12 May 2024 20:49:50 +0800 Subject: [PATCH 1/2] dev(narugo): add extractor --- imgutils/metrics/dbaesthetic.py | 31 ++++--------------------------- imgutils/tagging/wd14.py | 26 ++++++++++++++++++++++---- imgutils/utils/__init__.py | 1 + imgutils/utils/format.py | 26 ++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 31 deletions(-) create mode 100644 imgutils/utils/format.py diff --git a/imgutils/metrics/dbaesthetic.py b/imgutils/metrics/dbaesthetic.py index 7b221e88dd5..cab24120fa9 100644 --- a/imgutils/metrics/dbaesthetic.py +++ b/imgutils/metrics/dbaesthetic.py @@ -19,8 +19,9 @@ import numpy as np from huggingface_hub import hf_hub_download -from imgutils.data import ImageTyping -from imgutils.generic import ClassifyModel +from ..data import ImageTyping +from ..generic import ClassifyModel +from ..utils import vreplace __all__ = [ 'anime_dbaesthetic', @@ -40,30 +41,6 @@ } -def _value_replace(v, mapping): - """ - Replaces values in a data structure using a mapping dictionary. - - :param v: The input data structure. - :type v: Any - :param mapping: A dictionary mapping values to replacement values. - :type mapping: Dict - :return: The modified data structure. - :rtype: Any - """ - if isinstance(v, (list, tuple)): - return type(v)([_value_replace(vitem, mapping) for vitem in v]) - elif isinstance(v, dict): - return type(v)({key: _value_replace(value, mapping) for key, value in v.items()}) - else: - try: - _ = hash(v) - except TypeError: # pragma: no cover - return v - else: - return mapping.get(v, v) - - class AestheticModel: """ A model for assessing the aesthetic quality of anime images. @@ -171,7 +148,7 @@ def get_aesthetic(self, image: ImageTyping, model_name: str, fmt=('label', 'perc score, confidence = self.get_aesthetic_score(image, model_name) percentile = self.score_to_percentile(score, model_name) label = self.percentile_to_label(percentile) - return _value_replace( + return vreplace( v=fmt, mapping={ 'label': label, diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index 37eaf1ea4ec..a911943a094 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -16,7 +16,7 @@ from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping -from ..utils import open_onnx_model +from ..utils import open_onnx_model, vreplace SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" @@ -64,7 +64,10 @@ def _get_wd14_model(model_name): :rtype: ONNXModel """ _version_support_check(model_name) - return open_onnx_model(hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) + return open_onnx_model(hf_hub_download( + repo_id='deepghs/wd14_tagger_with_embeddings', + filename=f'{MODEL_NAMES[model_name]}/model.onnx', + )) @lru_cache() @@ -133,6 +136,7 @@ def get_wd14_tags( character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, + fmt=('rating', 'general', 'character'), ): """ Overview: @@ -155,6 +159,9 @@ def get_wd14_tags( :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool + :param fmt: Return format, default is ``('rating', 'general', 'character')``. + ``embedding`` is also supported for feature extraction. + :type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] @@ -189,8 +196,10 @@ def get_wd14_tags( image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name + assert len(model.get_outputs()) == 2 label_name = model.get_outputs()[0].name - preds = model.run([label_name], {input_name: image})[0] + emb_name = model.get_outputs()[1].name + preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) ratings_names = [labels[i] for i in rating_indexes] @@ -215,4 +224,13 @@ def get_wd14_tags( character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) - return rating, general_res, character_res + return vreplace( + fmt, + { + 'rating': rating, + 'general': general_res, + 'character': character_res, + 'tag': {**general_res, **character_res}, + 'embedding': embeddings, + } + ) diff --git a/imgutils/utils/__init__.py b/imgutils/utils/__init__.py index 37df85fbeb9..43479752294 100644 --- a/imgutils/utils/__init__.py +++ b/imgutils/utils/__init__.py @@ -3,6 +3,7 @@ Generic utilities for :mod:`imgutils`. """ from .area import * +from .format import * from .onnxruntime import * from .storage import * from .tqdm_ import * diff --git a/imgutils/utils/format.py b/imgutils/utils/format.py new file mode 100644 index 00000000000..a06c86d7ff5 --- /dev/null +++ b/imgutils/utils/format.py @@ -0,0 +1,26 @@ +__all__ = [ + 'vreplace', +] + + +def vreplace(v, mapping): + """ + Replaces values in a data structure using a mapping dictionary. + :param v: The input data structure. + :type v: Any + :param mapping: A dictionary mapping values to replacement values. + :type mapping: Dict + :return: The modified data structure. + :rtype: Any + """ + if isinstance(v, (list, tuple)): + return type(v)([vreplace(vitem, mapping) for vitem in v]) + elif isinstance(v, dict): + return type(v)({key: vreplace(value, mapping) for key, value in v.items()}) + else: + try: + _ = hash(v) + except TypeError: # pragma: no cover + return v + else: + return mapping.get(v, v) From f48fd2de2d101a4bfdce02762cf953d56f90f9d1 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Sun, 12 May 2024 21:09:35 +0800 Subject: [PATCH 2/2] dev(narugo): add embedding support --- imgutils/tagging/wd14.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index a911943a094..ed86b613126 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -231,6 +231,6 @@ def get_wd14_tags( 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, - 'embedding': embeddings, + 'embedding': embeddings[0], } )