diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 01f7fc824cc..8c563d64bb5 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -50,8 +50,6 @@ jobs: if: ${{ github.event_name == 'push' }} env: CI: 'true' - HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }} - HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }} with: shell: bash timeout_minutes: 20 @@ -122,8 +120,6 @@ jobs: uses: nick-fields/retry@v2 env: CI: 'true' - HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }} - HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }} with: shell: bash timeout_minutes: 20 diff --git a/.github/workflows/export.yml b/.github/workflows/export.yml index c99de51d347..021c9af9173 100644 --- a/.github/workflows/export.yml +++ b/.github/workflows/export.yml @@ -44,8 +44,6 @@ jobs: uses: nick-fields/retry@v2 env: CI: 'true' - HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }} - HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }} with: shell: bash timeout_minutes: 20 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 547e8740a45..eee99a49181 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -103,8 +103,6 @@ jobs: uses: nick-fields/retry@v2 env: CI: 'true' - HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }} - HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }} with: shell: bash timeout_minutes: 20 diff --git a/Makefile b/Makefile index 9b2aa145493..a8fab6ade5c 100644 --- a/Makefile +++ b/Makefile @@ -50,10 +50,10 @@ pdocs: dataset: mkdir -p ${DATASET_DIR} if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \ - git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \ + git clone https://huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \ fi if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \ - git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \ + git clone https://huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \ fi if [ ! -d ${DATASET_DIR}/images_test_v1 ]; then \ mkdir -p ${DATASET_DIR}/images_test_v1 && \ diff --git a/docs/source/api_doc/tagging/index.rst b/docs/source/api_doc/tagging/index.rst index 7a67a0e87c7..989a0928135 100644 --- a/docs/source/api_doc/tagging/index.rst +++ b/docs/source/api_doc/tagging/index.rst @@ -13,3 +13,4 @@ imgutils.tagging wd14 deepdanbooru format + overlap diff --git a/docs/source/api_doc/tagging/overlap.rst b/docs/source/api_doc/tagging/overlap.rst new file mode 100644 index 00000000000..f62b0a44daf --- /dev/null +++ b/docs/source/api_doc/tagging/overlap.rst @@ -0,0 +1,22 @@ +imgutils.tagging.overlap +==================================== + +.. currentmodule:: imgutils.tagging.overlap + +.. automodule:: imgutils.tagging.overlap + + +drop_overlap_tags +---------------------------------- + +.. autofunction:: drop_overlap_tags + + + +drop_overlaps_for_dict +---------------------------------- + +.. autofunction:: drop_overlaps_for_dict + + + diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py index ca1f028d67b..997e7e4f596 100644 --- a/imgutils/tagging/__init__.py +++ b/imgutils/tagging/__init__.py @@ -11,4 +11,5 @@ from .deepdanbooru import get_deepdanbooru_tags from .format import tags_to_text from .mldanbooru import get_mldanbooru_tags +from .overlap import drop_overlap_tags, drop_overlaps_for_dict from .wd14 import get_wd14_tags diff --git a/imgutils/tagging/deepdanbooru.py b/imgutils/tagging/deepdanbooru.py index 5c8c7a69029..afae2e4932b 100644 --- a/imgutils/tagging/deepdanbooru.py +++ b/imgutils/tagging/deepdanbooru.py @@ -16,6 +16,7 @@ from PIL import Image from huggingface_hub import hf_hub_download +from .overlap import drop_overlaps_for_dict from ..data import ImageTyping, load_image from ..utils import open_onnx_model @@ -31,7 +32,7 @@ def _get_deepdanbooru_labels(): general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, tag_real_names, \ - rating_indexes, general_indexes, character_indexes + rating_indexes, general_indexes, character_indexes @lru_cache() @@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray: def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, - general_threshold: float = 0.5, character_threshold: float = 0.5): + general_threshold: float = 0.5, character_threshold: float = 0.5, + drop_overlap: bool = False): """ Overview: Get tags for anime image based on ``deepdanbooru`` model. @@ -73,6 +75,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, The default value of ``False`` indicates the use of the original tag names. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. Example: @@ -120,6 +123,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, general_names = [labels[i] for i in general_indexes] general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) + if drop_overlap: + general_res = drop_overlaps_for_dict(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] diff --git a/imgutils/tagging/mldanbooru.py b/imgutils/tagging/mldanbooru.py index 7b1f28b81a6..de7741472fa 100644 --- a/imgutils/tagging/mldanbooru.py +++ b/imgutils/tagging/mldanbooru.py @@ -11,6 +11,7 @@ from PIL import Image from huggingface_hub import hf_hub_download +from .overlap import drop_overlaps_for_dict from ..data import load_image, ImageTyping from ..utils import open_onnx_model @@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, - threshold: float = 0.7, size: int = 448, keep_ratio: bool = False): + threshold: float = 0.7, size: int = 448, keep_ratio: bool = False, + drop_overlap: bool = False): """ Overview: Tagging image with ML-Danbooru, similar to @@ -72,6 +74,7 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, :param size: Size when passing the resized image into model, default is ``448``. :param keep_ratio: Keep the original ratio between height and width when passing the image into model, default is ``False``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. Example: Here are some images for example @@ -103,4 +106,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, output = (1 / (1 + np.exp(-native_output))).reshape(-1) tags = _get_mldanbooru_labels(use_real_name) pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0])) - return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold} + + general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold} + if drop_overlap: + general_tags = drop_overlaps_for_dict(general_tags) + return general_tags diff --git a/imgutils/tagging/overlap.py b/imgutils/tagging/overlap.py new file mode 100644 index 00000000000..47bddc7893c --- /dev/null +++ b/imgutils/tagging/overlap.py @@ -0,0 +1,113 @@ +import json +from functools import lru_cache +from typing import Mapping, List + +from huggingface_hub import hf_hub_download + + +@lru_cache() +def _get_overlap_tags() -> Mapping[str, List[str]]: + """ + Retrieve the overlap tag information from the specified Hugging Face Hub repository. + + This function downloads a JSON file containing tag overlap information and parses it into a dictionary. + + :return: A dictionary where keys are tags and values are lists of overlapping tags. + :rtype: Mapping[str, List[str]] + """ + json_file = hf_hub_download( + 'alea31415/tag_filtering', + 'overlap_tags.json', + repo_type='dataset', + ) + with open(json_file, 'r') as file: + data = json.load(file) + + return { + entry['query']: entry['has_overlap'] + for entry in data if 'has_overlap' in entry and entry['has_overlap'] + } + + +def drop_overlap_tags(tags: List[str]) -> List[str]: + """ + Drop overlapping tags from the given list of tags. + + This function removes tags that have overlaps with other tags based on precomputed overlap information. + + :param tags: A list of tags. + :type tags: List[str] + :return: A list of tags without overlaps. + :rtype: List[str] + + Examples:: + >>> from imgutils.tagging import drop_overlap_tags + >>> + >>> tags = [ + ... '1girl', 'solo', + ... 'long_hair', 'very_long_hair', 'red_hair', + ... 'breasts', 'medium_breasts', + ... ] + >>> drop_overlap_tags(tags) + ['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts'] + """ + overlap_tags_dict = _get_overlap_tags() + result_tags = [] + tags_underscore = [tag.replace(' ', '_') for tag in tags] + + for tag, tag_ in zip(tags, tags_underscore): + + to_remove = False + + # Case 1: If the tag is a key and some of the associated values are in tags + if tag_ in overlap_tags_dict: + overlap_values = set(val for val in overlap_tags_dict[tag_]) + if overlap_values.intersection(set(tags_underscore)): + to_remove = True + + # Checking superword condition separately + for tag_another in tags: + if tag in tag_another and tag != tag_another: + to_remove = True + break + + if not to_remove: + result_tags.append(tag) + + return result_tags + + +def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]: + """ + Drop overlapping tags from the given dictionary of tags with confidence scores. + + This function removes tags that have overlaps with other tags based on precomputed overlap information. + + :param tags: A dictionary where keys are tags and values are confidence scores. + :type tags: Mapping[str, float] + :return: A dictionary with non-overlapping tags and their corresponding confidence scores. + :rtype: Mapping[str, float] + + Examples:: + >>> from imgutils.tagging import drop_overlaps_for_dict + >>> + >>> tags = { + ... '1girl': 0.8849405313291128, + ... 'solo': 0.8548297594823425, + ... 'long_hair': 0.03910296474461261, + ... 'very_long_hair': 0.6615180440330748, + ... 'red_hair': 0.21552028866308015, + ... 'breasts': 0.3165260620737027, + ... 'medium_breasts': 0.47744464927382957, + ... } + >>> drop_overlaps_for_dict(tags) + { + '1girl': 0.8849405313291128, + 'solo': 0.8548297594823425, + 'very_long_hair': 0.6615180440330748, + 'red_hair': 0.21552028866308015, + 'medium_breasts': 0.47744464927382957 + } + """ + key_set = set(drop_overlap_tags(list(tags.keys()))) + return {tag: confidence for tag, confidence in tags.items() if tag in key_set} diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index c24f16e1326..63165919d33 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +from .overlap import drop_overlaps_for_dict from ..data import load_image, ImageTyping from ..utils import open_onnx_model @@ -83,7 +84,8 @@ def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", - general_threshold: float = 0.35, character_threshold: float = 0.85): + general_threshold: float = 0.35, character_threshold: float = 0.85, + drop_overlap: bool = False): """ Overview: Tagging image by wd14 v2 model. Similar to @@ -94,6 +96,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. Example: @@ -148,6 +151,8 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_names = [labels[i] for i in general_indexes] general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) + if drop_overlap: + general_res = drop_overlaps_for_dict(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] diff --git a/test/tagging/test_deepdanbooru.py b/test/tagging/test_deepdanbooru.py index 168448f88f8..eb53be0fc2c 100644 --- a/test/tagging/test_deepdanbooru.py +++ b/test/tagging/test_deepdanbooru.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -15,7 +15,7 @@ def _release_model_after_run(): @pytest.mark.unittest class TestTaggingDeepdanbooru: - def test_get_deepdanbooru_tags(self, _release_model_after_run): + def test_get_deepdanbooru_tags(self): rating, tags, chars = get_deepdanbooru_tags(get_testfile('6124220.jpg')) assert rating['rating:safe'] > 0.9 assert tags['greyscale'] >= 0.8 @@ -27,3 +27,47 @@ def test_get_deepdanbooru_tags(self, _release_model_after_run): assert tags['1girl'] >= 0.85 assert tags['ring'] > 0.8 assert chars['hu_tao_(genshin_impact)'] >= 0.7 + + def test_get_danbooru_tags_sample(self): + rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png')) + assert rating == pytest.approx({ + 'rating:safe': 8.940696716308594e-06, + 'rating:questionable': 0.012878596782684326, + 'rating:explicit': 0.992286205291748, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.9923416376113892, 'armpits': 0.9226008653640747, 'arms_behind_head': 0.5620371699333191, + 'arms_up': 0.7268614172935486, 'bangs': 0.7465004920959473, 'black_border': 0.9081975221633911, + 'blush': 0.9306209683418274, 'breasts': 0.9972158670425415, + 'eyebrows_visible_through_hair': 0.6717097163200378, 'hair_between_eyes': 0.7044132947921753, + 'hair_intakes': 0.6295598745346069, 'horns': 0.9387356042861938, 'letterboxed': 1.0, + 'long_hair': 0.9871174693107605, 'looking_at_viewer': 0.8953969478607178, + 'medium_breasts': 0.90318363904953, 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, + 'nude': 0.9452821016311646, 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120401501655579, + 'pussy': 0.9943721294403076, 'pussy_juice': 0.8238061666488647, 'red_hair': 0.9203640222549438, + 'smile': 0.6659414172172546, 'solo': 0.9483305811882019, 'spread_legs': 0.7633067965507507, + 'stomach': 0.5396291017532349, 'sweat': 0.7880321145057678, 'thighs': 0.7451953291893005, + 'uncensored': 0.9594683647155762, 'very_long_hair': 0.740706205368042, + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3) + + def test_get_danbooru_tags_drop_overlap(self): + rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True) + assert rating == pytest.approx({ + 'rating:safe': 8.940696716308594e-06, + 'rating:questionable': 0.012878596782684326, + 'rating:explicit': 0.992286205291748, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.9923416376113892, 'armpits': 0.9226007461547852, 'arms_behind_head': 0.5620364546775818, + 'arms_up': 0.7268615365028381, 'bangs': 0.7465004324913025, 'black_border': 0.9081975221633911, + 'blush': 0.9306209683418274, 'eyebrows_visible_through_hair': 0.6717095971107483, + 'hair_between_eyes': 0.7044129967689514, 'hair_intakes': 0.6295579671859741, 'horns': 0.938735842704773, + 'letterboxed': 1.0, 'looking_at_viewer': 0.8953973650932312, 'medium_breasts': 0.9031840562820435, + 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, 'nude': 0.9452821016311646, + 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120403289794922, 'pussy_juice': 0.8238056898117065, + 'red_hair': 0.9203639030456543, 'smile': 0.6659414172172546, 'solo': 0.948330819606781, + 'spread_legs': 0.7633066177368164, 'stomach': 0.5396295189857483, 'sweat': 0.7880324721336365, + 'thighs': 0.745195746421814, 'uncensored': 0.9594683647155762, 'very_long_hair': 0.7407056093215942 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3) diff --git a/test/tagging/test_mldanbooru.py b/test/tagging/test_mldanbooru.py index c1331aa9470..ecfcde61015 100644 --- a/test/tagging/test_mldanbooru.py +++ b/test/tagging/test_mldanbooru.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -22,3 +22,42 @@ def test_get_mldanbooru_tags(self, keep_ratio): tags = get_mldanbooru_tags(get_testfile('6125785.jpg'), keep_ratio=keep_ratio) assert tags['1girl'] >= 0.95 + + def test_get_mldanbooru_tags_sample(self): + tags = get_mldanbooru_tags(get_testfile('nude_girl.png')) + assert tags == pytest.approx({ + '1girl': 0.9999977350234985, 'breasts': 0.999940037727356, 'nipples': 0.999920129776001, + 'solo': 0.9993574023246765, 'pussy': 0.9993218183517456, 'horns': 0.9977452158927917, + 'nude': 0.995971143245697, 'purple_eyes': 0.9957809448242188, 'long_hair': 0.9929291605949402, + 'navel': 0.9814828038215637, 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, + 'pussy_juice': 0.959962785243988, 'blush': 0.9482676386833191, 'uncensored': 0.9446588158607483, + 'looking_at_viewer': 0.9295657873153687, 'red_hair': 0.919776201248169, + 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, 'arms_up': 0.8882529139518738, + 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, 'lying': 0.8653205037117004, + 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059, + 'very_long_hair': 0.8332053422927856, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409, + 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.7933878302574158, + 'thighs': 0.7835342884063721, 'hair_between_eyes': 0.7693091630935669, + 'eyebrows_visible_through_hair': 0.7672545313835144, 'closed_mouth': 0.7638942003250122, + 'breasts_apart': 0.7527053952217102, 'bed': 0.7515304088592529, 'slit_pupils': 0.7464283108711243, + 'barefoot': 0.7429600954055786, 'bed_sheet': 0.7186222076416016, 'fang': 0.7162102460861206, + 'clitoris': 0.7013473510742188, + }, abs=1e-3) + + def test_get_mldanbooru_tags_no_overlap(self): + tags = get_mldanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True) + assert tags == pytest.approx({ + '1girl': 0.9999977350234985, 'nipples': 0.999920129776001, 'solo': 0.9993574023246765, + 'horns': 0.9977452158927917, 'purple_eyes': 0.9957809448242188, 'navel': 0.9814828038215637, + 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, 'pussy_juice': 0.959962785243988, + 'blush': 0.9482676386833191, 'uncensored': 0.9446586966514587, 'looking_at_viewer': 0.9295657873153687, + 'red_hair': 0.9197760820388794, 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, + 'arms_up': 0.8882529139518738, 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, + 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059, + 'very_long_hair': 0.8332052230834961, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409, + 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.793387770652771, + 'thighs': 0.7835341691970825, 'hair_between_eyes': 0.7693091034889221, + 'eyebrows_visible_through_hair': 0.7672545909881592, 'closed_mouth': 0.7638942003250122, + 'breasts_apart': 0.7527053356170654, 'slit_pupils': 0.7464284300804138, 'barefoot': 0.7429600358009338, + 'bed_sheet': 0.7186222672462463, 'fang': 0.7162103652954102, 'clitoris': 0.7013473510742188 + }, abs=1e-3) diff --git a/test/tagging/test_overlap.py b/test/tagging/test_overlap.py new file mode 100644 index 00000000000..6535fc7e840 --- /dev/null +++ b/test/tagging/test_overlap.py @@ -0,0 +1,42 @@ +import pytest + +from imgutils.tagging import drop_overlaps_for_dict, drop_overlap_tags + + +@pytest.fixture() +def complex_dict_tags(): + return { + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, + 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, + 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, + 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, + 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, + 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, + 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, + 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, + 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901, + } + + +@pytest.mark.unittest +class TestTaggingOverlap: + def test_drop_overlap_tags(self): + assert drop_overlap_tags(['1girl', 'solo', 'long_hair', 'very_long_hair', 'red_hair']) == \ + ['1girl', 'solo', 'very_long_hair', 'red_hair'] + + def test_drop_overlaps_for_dict_complex(self, complex_dict_tags): + assert drop_overlaps_for_dict(complex_dict_tags) == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, + 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, + 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, + 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, + 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, + 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, + 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, + 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }) diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index 3318795a64b..897cf9e8729 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -26,3 +26,51 @@ def test_get_wd14_tags(self): assert 0.35 <= rating['sensitive'] <= 0.45 assert tags['1girl'] >= 0.95 assert chars['hu_tao_(genshin_impact)'] >= 0.95 + + def test_wd14_tags_sample(self): + rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png')) + assert rating == pytest.approx({ + 'general': 0.0020540356636047363, + 'sensitive': 0.0080718994140625, + 'questionable': 0.003170192241668701, + 'explicit': 0.984081506729126, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, + 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, + 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, + 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, + 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, + 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, + 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, + 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, + 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) + + def test_wd14_tags_no_overlap(self): + rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) + # print(tags) + assert rating == pytest.approx({ + 'general': 0.0020540356636047363, + 'sensitive': 0.0080718994140625, + 'questionable': 0.003170192241668701, + 'explicit': 0.984081506729126, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, + 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, + 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, + 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, + 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, + 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, + 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, + 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3)