Skip to content

Commit

Permalink
dev(narugo): add overlap dropping for tags
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 8, 2023
1 parent 17df867 commit 2c7cf25
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/tagging/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ imgutils.tagging
wd14
deepdanbooru
format
overlap
22 changes: 22 additions & 0 deletions docs/source/api_doc/tagging/overlap.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
imgutils.tagging.overlap
====================================

.. currentmodule:: imgutils.tagging.overlap

.. automodule:: imgutils.tagging.overlap


drop_overlap_tags
----------------------------------

.. autofunction:: drop_overlap_tags



drop_overlaps_for_dict
----------------------------------

.. autofunction:: drop_overlaps_for_dict



1 change: 1 addition & 0 deletions imgutils/tagging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .mldanbooru import get_mldanbooru_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .wd14 import get_wd14_tags
8 changes: 6 additions & 2 deletions imgutils/tagging/deepdanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

Expand All @@ -31,7 +32,7 @@ def _get_deepdanbooru_labels():
general_indexes = list(np.where(df["category"] == 0)[0])
character_indexes = list(np.where(df["category"] == 4)[0])
return tag_names, tag_real_names, \
rating_indexes, general_indexes, character_indexes
rating_indexes, general_indexes, character_indexes


@lru_cache()
Expand Down Expand Up @@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray:


def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_threshold: float = 0.5, character_threshold: float = 0.5):
general_threshold: float = 0.5, character_threshold: float = 0.5,
drop_overlap: bool = False):
"""
Overview:
Get tags for anime image based on ``deepdanbooru`` model.
Expand Down Expand Up @@ -120,6 +122,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
10 changes: 8 additions & 2 deletions imgutils/tagging/mldanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List


def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
threshold: float = 0.7, size: int = 448, keep_ratio: bool = False):
threshold: float = 0.7, size: int = 448, keep_ratio: bool = False,
drop_overlap: bool = False):
"""
Overview:
Tagging image with ML-Danbooru, similar to
Expand Down Expand Up @@ -103,4 +105,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
output = (1 / (1 + np.exp(-native_output))).reshape(-1)
tags = _get_mldanbooru_labels(use_real_name)
pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0]))
return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}

general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
if drop_overlap:
general_tags = drop_overlaps_for_dict(general_tags)
return general_tags
81 changes: 81 additions & 0 deletions imgutils/tagging/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
from functools import lru_cache
from typing import Mapping, List

from huggingface_hub import hf_hub_download


@lru_cache()
def _get_overlap_tags() -> Mapping[str, List[str]]:
"""
Retrieve the overlap tag information from the specified Hugging Face Hub repository.
This function downloads a JSON file containing tag overlap information and parses it into a dictionary.
:return: A dictionary where keys are tags and values are lists of overlapping tags.
:rtype: Mapping[str, List[str]]
"""
json_file = hf_hub_download(
'alea31415/tag_filtering',
'overlap_tags.json',
repo_type='dataset',
)
with open(json_file, 'r') as file:
data = json.load(file)

return {
entry['query']: entry['has_overlap']
for entry in data if 'has_overlap' in entry and entry['has_overlap']
}


def drop_overlap_tags(tags: List[str]) -> List[str]:
"""
Drop overlapping tags from the given list of tags.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A list of tags.
:type tags: List[str]
:return: A list of tags without overlaps.
:rtype: List[str]
"""
overlap_tags_dict = _get_overlap_tags()
result_tags = []
tags_underscore = [tag.replace(' ', '_') for tag in tags]

for tag, tag_ in zip(tags, tags_underscore):

to_remove = False

# Case 1: If the tag is a key and some of the associated values are in tags
if tag_ in overlap_tags_dict:
overlap_values = set(val for val in overlap_tags_dict[tag_])
if overlap_values.intersection(set(tags_underscore)):
to_remove = True

# Checking superword condition separately
for tag_another in tags:
if tag in tag_another and tag != tag_another:
to_remove = True
break

if not to_remove:
result_tags.append(tag)

return result_tags


def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]:
"""
Drop overlapping tags from the given dictionary of tags with confidence scores.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A dictionary where keys are tags and values are confidence scores.
:type tags: Mapping[str, float]
:return: A dictionary with non-overlapping tags and their corresponding confidence scores.
:rtype: Mapping[str, float]
"""
key_set = set(drop_overlap_tags(list(tags.keys())))
return {tag: confidence for tag, confidence in tags.items() if tag in key_set}
6 changes: 5 additions & 1 deletion imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pandas as pd

from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -83,7 +84,8 @@ def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]:


def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_threshold: float = 0.35, character_threshold: float = 0.85):
general_threshold: float = 0.35, character_threshold: float = 0.85,
drop_overlap: bool = False):
"""
Overview:
Tagging image by wd14 v2 model. Similar to
Expand Down Expand Up @@ -148,6 +150,8 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
44 changes: 44 additions & 0 deletions test/tagging/test_deepdanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,47 @@ def test_get_deepdanbooru_tags(self, _release_model_after_run):
assert tags['1girl'] >= 0.85
assert tags['ring'] > 0.8
assert chars['hu_tao_(genshin_impact)'] >= 0.7

def test_get_danbooru_tags_sample(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'))
assert rating == pytest.approx({
'rating:safe': 8.940696716308594e-06,
'rating:questionable': 0.012878596782684326,
'rating:explicit': 0.992286205291748,
}, abs=1e-3)
assert tags == pytest.approx({
'1girl': 0.9923416376113892, 'armpits': 0.9226008653640747, 'arms_behind_head': 0.5620371699333191,
'arms_up': 0.7268614172935486, 'bangs': 0.7465004920959473, 'black_border': 0.9081975221633911,
'blush': 0.9306209683418274, 'breasts': 0.9972158670425415,
'eyebrows_visible_through_hair': 0.6717097163200378, 'hair_between_eyes': 0.7044132947921753,
'hair_intakes': 0.6295598745346069, 'horns': 0.9387356042861938, 'letterboxed': 1.0,
'long_hair': 0.9871174693107605, 'looking_at_viewer': 0.8953969478607178,
'medium_breasts': 0.90318363904953, 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044,
'nude': 0.9452821016311646, 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120401501655579,
'pussy': 0.9943721294403076, 'pussy_juice': 0.8238061666488647, 'red_hair': 0.9203640222549438,
'smile': 0.6659414172172546, 'solo': 0.9483305811882019, 'spread_legs': 0.7633067965507507,
'stomach': 0.5396291017532349, 'sweat': 0.7880321145057678, 'thighs': 0.7451953291893005,
'uncensored': 0.9594683647155762, 'very_long_hair': 0.740706205368042,
}, abs=1e-3)
assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)

def test_get_danbooru_tags_drop_overlap(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True)
assert rating == pytest.approx({
'rating:safe': 8.940696716308594e-06,
'rating:questionable': 0.012878596782684326,
'rating:explicit': 0.992286205291748,
}, abs=1e-3)
assert tags == pytest.approx({
'1girl': 0.9923416376113892, 'armpits': 0.9226007461547852, 'arms_behind_head': 0.5620364546775818,
'arms_up': 0.7268615365028381, 'bangs': 0.7465004324913025, 'black_border': 0.9081975221633911,
'blush': 0.9306209683418274, 'eyebrows_visible_through_hair': 0.6717095971107483,
'hair_between_eyes': 0.7044129967689514, 'hair_intakes': 0.6295579671859741, 'horns': 0.938735842704773,
'letterboxed': 1.0, 'looking_at_viewer': 0.8953973650932312, 'medium_breasts': 0.9031840562820435,
'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, 'nude': 0.9452821016311646,
'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120403289794922, 'pussy_juice': 0.8238056898117065,
'red_hair': 0.9203639030456543, 'smile': 0.6659414172172546, 'solo': 0.948330819606781,
'spread_legs': 0.7633066177368164, 'stomach': 0.5396295189857483, 'sweat': 0.7880324721336365,
'thighs': 0.745195746421814, 'uncensored': 0.9594683647155762, 'very_long_hair': 0.7407056093215942
}, abs=1e-3)
assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)
39 changes: 39 additions & 0 deletions test/tagging/test_mldanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,42 @@ def test_get_mldanbooru_tags(self, keep_ratio):

tags = get_mldanbooru_tags(get_testfile('6125785.jpg'), keep_ratio=keep_ratio)
assert tags['1girl'] >= 0.95

def test_get_mldanbooru_tags_sample(self):
tags = get_mldanbooru_tags(get_testfile('nude_girl.png'))
assert tags == pytest.approx({
'1girl': 0.9999977350234985, 'breasts': 0.999940037727356, 'nipples': 0.999920129776001,
'solo': 0.9993574023246765, 'pussy': 0.9993218183517456, 'horns': 0.9977452158927917,
'nude': 0.995971143245697, 'purple_eyes': 0.9957809448242188, 'long_hair': 0.9929291605949402,
'navel': 0.9814828038215637, 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068,
'pussy_juice': 0.959962785243988, 'blush': 0.9482676386833191, 'uncensored': 0.9446588158607483,
'looking_at_viewer': 0.9295657873153687, 'red_hair': 0.919776201248169,
'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, 'arms_up': 0.8882529139518738,
'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, 'lying': 0.8653205037117004,
'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059,
'very_long_hair': 0.8332053422927856, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409,
'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.7933878302574158,
'thighs': 0.7835342884063721, 'hair_between_eyes': 0.7693091630935669,
'eyebrows_visible_through_hair': 0.7672545313835144, 'closed_mouth': 0.7638942003250122,
'breasts_apart': 0.7527053952217102, 'bed': 0.7515304088592529, 'slit_pupils': 0.7464283108711243,
'barefoot': 0.7429600954055786, 'bed_sheet': 0.7186222076416016, 'fang': 0.7162102460861206,
'clitoris': 0.7013473510742188,
}, abs=1e-3)

def test_get_mldanbooru_tags_no_overlap(self):
tags = get_mldanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True)
assert tags == pytest.approx({
'1girl': 0.9999977350234985, 'nipples': 0.999920129776001, 'solo': 0.9993574023246765,
'horns': 0.9977452158927917, 'purple_eyes': 0.9957809448242188, 'navel': 0.9814828038215637,
'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, 'pussy_juice': 0.959962785243988,
'blush': 0.9482676386833191, 'uncensored': 0.9446586966514587, 'looking_at_viewer': 0.9295657873153687,
'red_hair': 0.9197760820388794, 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917,
'arms_up': 0.8882529139518738, 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079,
'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059,
'very_long_hair': 0.8332052230834961, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409,
'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.793387770652771,
'thighs': 0.7835341691970825, 'hair_between_eyes': 0.7693091034889221,
'eyebrows_visible_through_hair': 0.7672545909881592, 'closed_mouth': 0.7638942003250122,
'breasts_apart': 0.7527053356170654, 'slit_pupils': 0.7464284300804138, 'barefoot': 0.7429600358009338,
'bed_sheet': 0.7186222672462463, 'fang': 0.7162103652954102, 'clitoris': 0.7013473510742188
}, abs=1e-3)
42 changes: 42 additions & 0 deletions test/tagging/test_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from imgutils.tagging import drop_overlaps_for_dict, drop_overlap_tags


@pytest.fixture()
def complex_dict_tags():
return {
'1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507,
'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438,
'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896,
'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697,
'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648,
'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317,
'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223,
'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056,
'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686,
'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901,
}


@pytest.mark.unittest
class TestTaggingOverlap:
def test_drop_overlap_tags(self):
assert drop_overlap_tags(['1girl', 'solo', 'long_hair', 'very_long_hair', 'red_hair']) == \
['1girl', 'solo', 'very_long_hair', 'red_hair']

def test_drop_overlaps_for_dict_complex(self, complex_dict_tags):
assert drop_overlaps_for_dict(complex_dict_tags) == pytest.approx({
'1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698,
'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605,
'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948,
'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633,
'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317,
'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658,
'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423,
'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345,
'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
'clitoris': 0.5310801267623901
})
Loading

0 comments on commit 2c7cf25

Please sign in to comment.