Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev(narugo): add extractor #91

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 4 additions & 27 deletions imgutils/metrics/dbaesthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import numpy as np
from huggingface_hub import hf_hub_download

from imgutils.data import ImageTyping
from imgutils.generic import ClassifyModel
from ..data import ImageTyping
from ..generic import ClassifyModel
from ..utils import vreplace

__all__ = [
'anime_dbaesthetic',
Expand All @@ -40,30 +41,6 @@
}


def _value_replace(v, mapping):
"""
Replaces values in a data structure using a mapping dictionary.

:param v: The input data structure.
:type v: Any
:param mapping: A dictionary mapping values to replacement values.
:type mapping: Dict
:return: The modified data structure.
:rtype: Any
"""
if isinstance(v, (list, tuple)):
return type(v)([_value_replace(vitem, mapping) for vitem in v])
elif isinstance(v, dict):
return type(v)({key: _value_replace(value, mapping) for key, value in v.items()})
else:
try:
_ = hash(v)
except TypeError: # pragma: no cover
return v
else:
return mapping.get(v, v)


class AestheticModel:
"""
A model for assessing the aesthetic quality of anime images.
Expand Down Expand Up @@ -171,7 +148,7 @@ def get_aesthetic(self, image: ImageTyping, model_name: str, fmt=('label', 'perc
score, confidence = self.get_aesthetic_score(image, model_name)
percentile = self.score_to_percentile(score, model_name)
label = self.percentile_to_label(percentile)
return _value_replace(
return vreplace(
v=fmt,
mapping={
'label': label,
Expand Down
26 changes: 22 additions & 4 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
from ..utils import open_onnx_model, vreplace

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
Expand Down Expand Up @@ -64,7 +64,10 @@ def _get_wd14_model(model_name):
:rtype: ONNXModel
"""
_version_support_check(model_name)
return open_onnx_model(hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME))
return open_onnx_model(hf_hub_download(
repo_id='deepghs/wd14_tagger_with_embeddings',
filename=f'{MODEL_NAMES[model_name]}/model.onnx',
))


@lru_cache()
Expand Down Expand Up @@ -133,6 +136,7 @@ def get_wd14_tags(
character_mcut_enabled: bool = False,
no_underline: bool = False,
drop_overlap: bool = False,
fmt=('rating', 'general', 'character'),
):
"""
Overview:
Expand All @@ -155,6 +159,9 @@ def get_wd14_tags(
:type no_underline: bool
:param drop_overlap: If True, drops overlapping tags.
:type drop_overlap: bool
:param fmt: Return format, default is ``('rating', 'general', 'character')``.
``embedding`` is also supported for feature extraction.
:type fmt: Any
:return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
:rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]

Expand Down Expand Up @@ -189,8 +196,10 @@ def get_wd14_tags(
image = _prepare_image_for_tagging(image, target_size)

input_name = model.get_inputs()[0].name
assert len(model.get_outputs()) == 2
label_name = model.get_outputs()[0].name
preds = model.run([label_name], {input_name: image})[0]
emb_name = model.get_outputs()[1].name
preds, embeddings = model.run([label_name, emb_name], {input_name: image})
labels = list(zip(tag_names, preds[0].astype(float)))

ratings_names = [labels[i] for i in rating_indexes]
Expand All @@ -215,4 +224,13 @@ def get_wd14_tags(
character_res = [x for x in character_names if x[1] > character_threshold]
character_res = dict(character_res)

return rating, general_res, character_res
return vreplace(
fmt,
{
'rating': rating,
'general': general_res,
'character': character_res,
'tag': {**general_res, **character_res},
'embedding': embeddings[0],
}
)
1 change: 1 addition & 0 deletions imgutils/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Generic utilities for :mod:`imgutils`.
"""
from .area import *
from .format import *
from .onnxruntime import *
from .storage import *
from .tqdm_ import *
26 changes: 26 additions & 0 deletions imgutils/utils/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Names exported by a star-import of this module.
__all__ = [
'vreplace',
]


def vreplace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.

    Lists, tuples (including named tuples) and dictionaries are traversed
    recursively and rebuilt with the same container type. Dictionary keys are
    kept as-is; only dictionary values are replaced. A hashable leaf that
    appears as a key in ``mapping`` is replaced by the mapped value; unhashable
    leaves and leaves missing from ``mapping`` are returned unchanged.

    :param v: The input data structure.
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :rtype: Any
    """
    if isinstance(v, (list, tuple)):
        items = [vreplace(vitem, mapping) for vitem in v]
        if isinstance(v, tuple) and hasattr(v, '_fields'):
            # Named tuples take their fields as positional arguments, so
            # passing a single list (as done for plain tuples) raises TypeError.
            return type(v)(*items)
        return type(v)(items)
    elif isinstance(v, dict):
        return type(v)({key: vreplace(value, mapping) for key, value in v.items()})
    else:
        try:
            # Only hashable objects can be looked up in the mapping.
            _ = hash(v)
        except TypeError:  # pragma: no cover
            return v
        else:
            return mapping.get(v, v)
Loading