Skip to content

Commit

Permalink
Merge pull request #43 from deepghs/dev/overlap
Browse files Browse the repository at this point in the history
dev(narugo): add overlap dropping for tags
  • Loading branch information
narugo1992 authored Oct 8, 2023
2 parents 17df867 + 5e9c1ee commit da7d591
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 19 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ jobs:
if: ${{ github.event_name == 'push' }}
env:
CI: 'true'
HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
Expand Down Expand Up @@ -122,8 +120,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/export.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ pdocs:
dataset:
mkdir -p ${DATASET_DIR}
if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \
git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
git clone https://huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
fi
if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \
git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
git clone https://huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
fi
if [ ! -d ${DATASET_DIR}/images_test_v1 ]; then \
mkdir -p ${DATASET_DIR}/images_test_v1 && \
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_doc/tagging/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ imgutils.tagging
wd14
deepdanbooru
format
overlap
22 changes: 22 additions & 0 deletions docs/source/api_doc/tagging/overlap.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
imgutils.tagging.overlap
====================================

.. currentmodule:: imgutils.tagging.overlap

.. automodule:: imgutils.tagging.overlap


drop_overlap_tags
----------------------------------

.. autofunction:: drop_overlap_tags



drop_overlaps_for_dict
----------------------------------

.. autofunction:: drop_overlaps_for_dict



1 change: 1 addition & 0 deletions imgutils/tagging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .mldanbooru import get_mldanbooru_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .wd14 import get_wd14_tags
9 changes: 7 additions & 2 deletions imgutils/tagging/deepdanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

Expand All @@ -31,7 +32,7 @@ def _get_deepdanbooru_labels():
general_indexes = list(np.where(df["category"] == 0)[0])
character_indexes = list(np.where(df["category"] == 4)[0])
return tag_names, tag_real_names, \
rating_indexes, general_indexes, character_indexes
rating_indexes, general_indexes, character_indexes


@lru_cache()
Expand Down Expand Up @@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray:


def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_threshold: float = 0.5, character_threshold: float = 0.5):
general_threshold: float = 0.5, character_threshold: float = 0.5,
drop_overlap: bool = False):
"""
Overview:
Get tags for anime image based on ``deepdanbooru`` model.
Expand All @@ -73,6 +75,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
The default value of ``False`` indicates the use of the original tag names.
:param general_threshold: Threshold for default tags, default is ``0.35``.
:param character_threshold: Threshold for character tags, default is ``0.85``.
:param drop_overlap: Drop overlap tags or not, default is ``False``.
:return: Tagging results for levels, features and characters.
Example:
Expand Down Expand Up @@ -120,6 +123,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
11 changes: 9 additions & 2 deletions imgutils/tagging/mldanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List


def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
threshold: float = 0.7, size: int = 448, keep_ratio: bool = False):
threshold: float = 0.7, size: int = 448, keep_ratio: bool = False,
drop_overlap: bool = False):
"""
Overview:
Tagging image with ML-Danbooru, similar to
Expand All @@ -72,6 +74,7 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
:param size: Size when passing the resized image into model, default is ``448``.
:param keep_ratio: Keep the original ratio between height and width when passing the image into
model, default is ``False``.
:param drop_overlap: Drop overlap tags or not, default is ``False``.
Example:
Here are some images for example
Expand Down Expand Up @@ -103,4 +106,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
output = (1 / (1 + np.exp(-native_output))).reshape(-1)
tags = _get_mldanbooru_labels(use_real_name)
pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0]))
return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}

general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
if drop_overlap:
general_tags = drop_overlaps_for_dict(general_tags)
return general_tags
113 changes: 113 additions & 0 deletions imgutils/tagging/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import json
from functools import lru_cache
from typing import Mapping, List

from huggingface_hub import hf_hub_download


@lru_cache()
def _get_overlap_tags() -> Mapping[str, List[str]]:
"""
Retrieve the overlap tag information from the specified Hugging Face Hub repository.
This function downloads a JSON file containing tag overlap information and parses it into a dictionary.
:return: A dictionary where keys are tags and values are lists of overlapping tags.
:rtype: Mapping[str, List[str]]
"""
json_file = hf_hub_download(
'alea31415/tag_filtering',
'overlap_tags.json',
repo_type='dataset',
)
with open(json_file, 'r') as file:
data = json.load(file)

return {
entry['query']: entry['has_overlap']
for entry in data if 'has_overlap' in entry and entry['has_overlap']
}


def drop_overlap_tags(tags: List[str]) -> List[str]:
"""
Drop overlapping tags from the given list of tags.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A list of tags.
:type tags: List[str]
:return: A list of tags without overlaps.
:rtype: List[str]
Examples::
>>> from imgutils.tagging import drop_overlap_tags
>>>
>>> tags = [
... '1girl', 'solo',
... 'long_hair', 'very_long_hair', 'red_hair',
... 'breasts', 'medium_breasts',
... ]
>>> drop_overlap_tags(tags)
['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts']
"""
overlap_tags_dict = _get_overlap_tags()
result_tags = []
tags_underscore = [tag.replace(' ', '_') for tag in tags]

for tag, tag_ in zip(tags, tags_underscore):

to_remove = False

# Case 1: If the tag is a key and some of the associated values are in tags
if tag_ in overlap_tags_dict:
overlap_values = set(val for val in overlap_tags_dict[tag_])
if overlap_values.intersection(set(tags_underscore)):
to_remove = True

# Checking superword condition separately
for tag_another in tags:
if tag in tag_another and tag != tag_another:
to_remove = True
break

if not to_remove:
result_tags.append(tag)

return result_tags


def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]:
"""
Drop overlapping tags from the given dictionary of tags with confidence scores.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A dictionary where keys are tags and values are confidence scores.
:type tags: Mapping[str, float]
:return: A dictionary with non-overlapping tags and their corresponding confidence scores.
:rtype: Mapping[str, float]
Examples::
>>> from imgutils.tagging import drop_overlaps_for_dict
>>>
>>> tags = {
... '1girl': 0.8849405313291128,
... 'solo': 0.8548297594823425,
... 'long_hair': 0.03910296474461261,
... 'very_long_hair': 0.6615180440330748,
... 'red_hair': 0.21552028866308015,
... 'breasts': 0.3165260620737027,
... 'medium_breasts': 0.47744464927382957,
... }
>>> drop_overlaps_for_dict(tags)
{
'1girl': 0.8849405313291128,
'solo': 0.8548297594823425,
'very_long_hair': 0.6615180440330748,
'red_hair': 0.21552028866308015,
'medium_breasts': 0.47744464927382957
}
"""
key_set = set(drop_overlap_tags(list(tags.keys())))
return {tag: confidence for tag, confidence in tags.items() if tag in key_set}
7 changes: 6 additions & 1 deletion imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pandas as pd

from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

Expand Down Expand Up @@ -83,7 +84,8 @@ def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]:


def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_threshold: float = 0.35, character_threshold: float = 0.85):
general_threshold: float = 0.35, character_threshold: float = 0.85,
drop_overlap: bool = False):
"""
Overview:
Tagging image by wd14 v2 model. Similar to
Expand All @@ -94,6 +96,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``.
:param general_threshold: Threshold for default tags, default is ``0.35``.
:param character_threshold: Threshold for character tags, default is ``0.85``.
:param drop_overlap: Drop overlap tags or not, default is ``False``.
:return: Tagging results for levels, features and characters.
Example:
Expand Down Expand Up @@ -148,6 +151,8 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
if drop_overlap:
general_res = drop_overlaps_for_dict(general_res)

# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
Expand Down
48 changes: 46 additions & 2 deletions test/tagging/test_deepdanbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from test.testings import get_testfile


@pytest.fixture()
@pytest.fixture(autouse=True, scope='module')
def _release_model_after_run():
try:
yield
Expand All @@ -15,7 +15,7 @@ def _release_model_after_run():

@pytest.mark.unittest
class TestTaggingDeepdanbooru:
def test_get_deepdanbooru_tags(self, _release_model_after_run):
def test_get_deepdanbooru_tags(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('6124220.jpg'))
assert rating['rating:safe'] > 0.9
assert tags['greyscale'] >= 0.8
Expand All @@ -27,3 +27,47 @@ def test_get_deepdanbooru_tags(self, _release_model_after_run):
assert tags['1girl'] >= 0.85
assert tags['ring'] > 0.8
assert chars['hu_tao_(genshin_impact)'] >= 0.7

def test_get_danbooru_tags_sample(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'))
assert rating == pytest.approx({
'rating:safe': 8.940696716308594e-06,
'rating:questionable': 0.012878596782684326,
'rating:explicit': 0.992286205291748,
}, abs=1e-3)
assert tags == pytest.approx({
'1girl': 0.9923416376113892, 'armpits': 0.9226008653640747, 'arms_behind_head': 0.5620371699333191,
'arms_up': 0.7268614172935486, 'bangs': 0.7465004920959473, 'black_border': 0.9081975221633911,
'blush': 0.9306209683418274, 'breasts': 0.9972158670425415,
'eyebrows_visible_through_hair': 0.6717097163200378, 'hair_between_eyes': 0.7044132947921753,
'hair_intakes': 0.6295598745346069, 'horns': 0.9387356042861938, 'letterboxed': 1.0,
'long_hair': 0.9871174693107605, 'looking_at_viewer': 0.8953969478607178,
'medium_breasts': 0.90318363904953, 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044,
'nude': 0.9452821016311646, 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120401501655579,
'pussy': 0.9943721294403076, 'pussy_juice': 0.8238061666488647, 'red_hair': 0.9203640222549438,
'smile': 0.6659414172172546, 'solo': 0.9483305811882019, 'spread_legs': 0.7633067965507507,
'stomach': 0.5396291017532349, 'sweat': 0.7880321145057678, 'thighs': 0.7451953291893005,
'uncensored': 0.9594683647155762, 'very_long_hair': 0.740706205368042,
}, abs=1e-3)
assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)

def test_get_danbooru_tags_drop_overlap(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True)
assert rating == pytest.approx({
'rating:safe': 8.940696716308594e-06,
'rating:questionable': 0.012878596782684326,
'rating:explicit': 0.992286205291748,
}, abs=1e-3)
assert tags == pytest.approx({
'1girl': 0.9923416376113892, 'armpits': 0.9226007461547852, 'arms_behind_head': 0.5620364546775818,
'arms_up': 0.7268615365028381, 'bangs': 0.7465004324913025, 'black_border': 0.9081975221633911,
'blush': 0.9306209683418274, 'eyebrows_visible_through_hair': 0.6717095971107483,
'hair_between_eyes': 0.7044129967689514, 'hair_intakes': 0.6295579671859741, 'horns': 0.938735842704773,
'letterboxed': 1.0, 'looking_at_viewer': 0.8953973650932312, 'medium_breasts': 0.9031840562820435,
'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, 'nude': 0.9452821016311646,
'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120403289794922, 'pussy_juice': 0.8238056898117065,
'red_hair': 0.9203639030456543, 'smile': 0.6659414172172546, 'solo': 0.948330819606781,
'spread_legs': 0.7633066177368164, 'stomach': 0.5396295189857483, 'sweat': 0.7880324721336365,
'thighs': 0.745195746421814, 'uncensored': 0.9594683647155762, 'very_long_hair': 0.7407056093215942
}, abs=1e-3)
assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)
Loading

0 comments on commit da7d591

Please sign in to comment.