Skip to content

Commit

Permalink
dev(narugo): merge 2 functions for overlap
Browse files: browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 17, 2023
1 parent bfaa8d1 commit f1b2ef5
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 51 deletions.
2 changes: 1 addition & 1 deletion imgutils/tagging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags
4 changes: 2 additions & 2 deletions imgutils/tagging/deepdanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

Expand Down Expand Up @@ -124,7 +124,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)
general_res = drop_overlap_tags(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
4 changes: 2 additions & 2 deletions imgutils/tagging/mldanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -109,5 +109,5 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,

general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
if drop_overlap:
general_tags = drop_overlaps_for_dict(general_tags)
general_tags = drop_overlap_tags(general_tags)
return general_tags
73 changes: 33 additions & 40 deletions imgutils/tagging/overlap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
from functools import lru_cache
from typing import Mapping, List
from typing import Mapping, List, Union

from huggingface_hub import hf_hub_download

Expand All @@ -26,7 +27,7 @@ def _get_overlap_tags() -> Mapping[str, List[str]]:
return data


def drop_overlap_tags(tags: List[str]) -> List[str]:
def drop_overlap_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
"""
Drop overlapping tags from the given list of tags.
Expand All @@ -47,13 +48,35 @@ def drop_overlap_tags(tags: List[str]) -> List[str]:
... ]
>>> drop_overlap_tags(tags)
['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts']
>>>
>>> tags = {
... '1girl': 0.8849405313291128,
... 'solo': 0.8548297594823425,
... 'long_hair': 0.03910296474461261,
... 'very_long_hair': 0.6615180440330748,
... 'red_hair': 0.21552028866308015,
... 'breasts': 0.3165260620737027,
... 'medium_breasts': 0.47744464927382957,
... }
>>> drop_overlap_tags(tags)
{
'1girl': 0.8849405313291128,
'solo': 0.8548297594823425,
'very_long_hair': 0.6615180440330748,
'red_hair': 0.21552028866308015,
'medium_breasts': 0.47744464927382957
}
"""
overlap_tags_dict = _get_overlap_tags()
result_tags = []
_origin_tags = copy.deepcopy(tags)
if isinstance(tags, dict):
tags = list(tags.keys())
tags_underscore = [tag.replace(' ', '_') for tag in tags]

tags: List[str]
tags_underscore: List[str]
for tag, tag_ in zip(tags, tags_underscore):

to_remove = False

# Case 1: If the tag is a key and some of the associated values are in tags
Expand All @@ -71,40 +94,10 @@ def drop_overlap_tags(tags: List[str]) -> List[str]:
if not to_remove:
result_tags.append(tag)

return result_tags


def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]:
"""
Drop overlapping tags from the given dictionary of tags with confidence scores.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A dictionary where keys are tags and values are confidence scores.
:type tags: Mapping[str, float]
:return: A dictionary with non-overlapping tags and their corresponding confidence scores.
:rtype: Mapping[str, float]
Examples::
>>> from imgutils.tagging import drop_overlaps_for_dict
>>>
>>> tags = {
... '1girl': 0.8849405313291128,
... 'solo': 0.8548297594823425,
... 'long_hair': 0.03910296474461261,
... 'very_long_hair': 0.6615180440330748,
... 'red_hair': 0.21552028866308015,
... 'breasts': 0.3165260620737027,
... 'medium_breasts': 0.47744464927382957,
... }
>>> drop_overlaps_for_dict(tags)
{
'1girl': 0.8849405313291128,
'solo': 0.8548297594823425,
'very_long_hair': 0.6615180440330748,
'red_hair': 0.21552028866308015,
'medium_breasts': 0.47744464927382957
}
"""
key_set = set(drop_overlap_tags(list(tags.keys())))
return {tag: confidence for tag, confidence in tags.items() if tag in key_set}
if isinstance(_origin_tags, list):
return result_tags
elif isinstance(_origin_tags, dict):
_rtags_set = set(result_tags)
return {key: value for key, value in _origin_tags.items() if key in _rtags_set}
else:
raise TypeError(f'Unknown tags type - {_origin_tags!r}.') # pragma: no cover
4 changes: 2 additions & 2 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -152,7 +152,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)
general_res = drop_overlap_tags(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
13 changes: 9 additions & 4 deletions test/tagging/test_overlap.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import pytest

from imgutils.tagging import drop_overlaps_for_dict, drop_overlap_tags
from imgutils.tagging import drop_overlap_tags


@pytest.mark.unittest
class TestTaggingOverlap:
def test_drop_overlap_tags(self):
def test_drop_overlap_tags(self, complex_dict_tags):
assert drop_overlap_tags(['1girl', 'solo', 'long_hair', 'very_long_hair', 'red_hair']) == \
['1girl', 'solo', 'very_long_hair', 'red_hair']

def test_drop_overlaps_for_dict_complex(self, complex_dict_tags):
assert drop_overlaps_for_dict(complex_dict_tags) == pytest.approx({
assert drop_overlap_tags(complex_dict_tags) == pytest.approx({
'1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698,
'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605,
'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948,
Expand All @@ -23,3 +22,9 @@ def test_drop_overlaps_for_dict_complex(self, complex_dict_tags):
'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
'clitoris': 0.5310801267623901
})

def test_drop_overlap_tags_invalid(self):
with pytest.raises(TypeError):
drop_overlap_tags(1)
with pytest.raises(TypeError):
drop_overlap_tags(None)

0 comments on commit f1b2ef5

Please sign in to comment.