diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml
index 01f7fc824cc..8c563d64bb5 100644
--- a/.github/workflows/doc.yml
+++ b/.github/workflows/doc.yml
@@ -50,8 +50,6 @@ jobs:
         if: ${{ github.event_name == 'push' }}
         env:
           CI: 'true'
-          HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
-          HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
         with:
           shell: bash
           timeout_minutes: 20
@@ -122,8 +120,6 @@ jobs:
         uses: nick-fields/retry@v2
         env:
           CI: 'true'
-          HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
-          HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
         with:
           shell: bash
           timeout_minutes: 20
diff --git a/.github/workflows/export.yml b/.github/workflows/export.yml
index c99de51d347..021c9af9173 100644
--- a/.github/workflows/export.yml
+++ b/.github/workflows/export.yml
@@ -44,8 +44,6 @@ jobs:
         uses: nick-fields/retry@v2
         env:
           CI: 'true'
-          HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
-          HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
         with:
           shell: bash
           timeout_minutes: 20
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 547e8740a45..eee99a49181 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -103,8 +103,6 @@ jobs:
         uses: nick-fields/retry@v2
         env:
           CI: 'true'
-          HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
-          HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
         with:
           shell: bash
           timeout_minutes: 20
diff --git a/Makefile b/Makefile
index 9b2aa145493..a8fab6ade5c 100644
--- a/Makefile
+++ b/Makefile
@@ -50,10 +50,10 @@ pdocs:
 dataset:
 	mkdir -p ${DATASET_DIR}
 	if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \
-		git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
+		git clone https://huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
 	fi
 	if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \
-		git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
+		git clone https://huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
 	fi
 	if [ ! -d ${DATASET_DIR}/images_test_v1 ]; then \
 		mkdir -p ${DATASET_DIR}/images_test_v1 && \
diff --git a/docs/source/_libs/plot.py b/docs/source/_libs/plot.py
index 87bcd71ebc3..8afe3ff135f 100644
--- a/docs/source/_libs/plot.py
+++ b/docs/source/_libs/plot.py
@@ -1,6 +1,7 @@
 from typing import Tuple, List
 
 import matplotlib.pyplot as plt
+import numpy as np
 from PIL import Image
 
 from cli import _wrap_func_as_cli
@@ -44,6 +45,14 @@ def image_plot(*images, save_as: str, columns=2, keep_axis: bool = False, figsiz
     n = len(images)
     rows = (n + columns - 1) // columns
     fig, axs = plt.subplots(rows, columns, figsize=figsize)
+    if rows == 1 and columns == 1:
+        axs = np.array([[axs]])
+    elif rows == 1:
+        axs = axs[None, ...]
+    elif columns == 1:
+        axs = axs[..., None]
+    else:
+        pass
     plt.subplots_adjust(wspace=0.2, hspace=0.15)
     for i, img in enumerate(images, start=0):
         xi, yi = i // columns, i % columns
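A note on the ``plot.py`` hunk above: ``plt.subplots`` returns a bare ``Axes`` for a 1x1 grid and a 1-D array when either dimension is 1, so the added branches normalize ``axs`` to a 2-D array and keep the ``axs[xi, yi]`` indexing below valid. The same shape guarantee is available via ``squeeze=False``, which is what the patch emulates::

    import matplotlib.pyplot as plt

    # squeeze=False makes subplots always return a 2-D (rows x columns) Axes array,
    # matching what the manual normalization in image_plot produces
    fig, axs = plt.subplots(1, 1, squeeze=False)
    axs[0, 0].set_title('indexable as axs[row, col] for any grid shape')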
diff --git a/docs/source/api_doc/detect/eye.rst b/docs/source/api_doc/detect/eye.rst
new file mode 100644
index 00000000000..98f39be8362
--- /dev/null
+++ b/docs/source/api_doc/detect/eye.rst
@@ -0,0 +1,14 @@
+imgutils.detect.eye
+==========================
+
+.. currentmodule:: imgutils.detect.eye
+
+.. automodule:: imgutils.detect.eye
+
+
+detect_eyes
+------------------------------
+
+.. autofunction:: detect_eyes
+
+
diff --git a/docs/source/api_doc/detect/eye_detect_benchmark.plot.py b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py
new file mode 100644
index 00000000000..f4c31266ad7
--- /dev/null
+++ b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py
@@ -0,0 +1,37 @@
+import random
+
+from benchmark import BaseBenchmark, create_plot_cli
+from imgutils.detect import detect_eyes
+
+
+class EyeDetectBenchmark(BaseBenchmark):
+    def __init__(self, level, version):
+        BaseBenchmark.__init__(self)
+        self.level = level
+        self.version = version
+
+    def load(self):
+        from imgutils.detect.eye import _open_eye_detect_model
+        _ = _open_eye_detect_model(level=self.level, version=self.version)
+
+    def unload(self):
+        from imgutils.detect.eye import _open_eye_detect_model
+        _open_eye_detect_model.cache_clear()
+
+    def run(self):
+        image_file = random.choice(self.all_images)
+        _ = detect_eyes(image_file, level=self.level, version=self.version)
+
+
+if __name__ == '__main__':
+    create_plot_cli(
+        [
+            ('eye v1.0 (yolov8s)', EyeDetectBenchmark('s', 'v1.0')),
+            ('eye v1.0 (yolov8n)', EyeDetectBenchmark('n', 'v1.0')),
+            ('eye v0.8 (yolov8s)', EyeDetectBenchmark('s', 'v0.8')),
+            ('eye v0.7 (yolov8s)', EyeDetectBenchmark('s', 'v0.7')),
+        ],
+        title='Benchmark for Anime Eyes Detections',
+        run_times=10,
+        try_times=20,
+    )()
diff --git a/docs/source/api_doc/detect/eye_detect_benchmark.plot.py.svg b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py.svg
new file mode 100644
index 00000000000..1679499c798
[2546-line Matplotlib v3.7.3 SVG (2023-09-26): "Benchmark for Anime Eyes Detections" figure; vector data omitted]
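Every entry returned by ``detect_eyes`` is an ``((x0, y0, x1, y1), 'eye', score)`` triple in source-image pixel coordinates, so the boxes feed straight into PIL. A minimal cropping sketch, assuming a local image file::

    from PIL import Image

    from imgutils.detect import detect_eyes

    image = Image.open('your_image.png')  # hypothetical local file
    for i, ((x0, y0, x1, y1), label, score) in enumerate(detect_eyes(image)):
        # boxes are pixel coordinates on the original image
        image.crop((x0, y0, x1, y1)).save(f'eye_{i}_{score:.2f}.png')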
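``detect_text`` downsizes oversized inputs before inference: when ``width * height`` exceeds ``max_area_size ** 2`` (default ``640 ** 2``), both sides are divided by ``r = sqrt(width * height / max_area_size ** 2)`` and every returned box is scaled back up by the same factor. A worked sketch of that arithmetic, with illustrative numbers::

    # resize factor used by detect_text for a hypothetical 1920x1080 input
    width, height, max_area_size = 1920, 1080, 640
    r = ((width * height) / (max_area_size ** 2)) ** 0.5  # 2.25
    inference_size = (int(width / r), int(height / r))    # (853, 480), area ~= 640 ** 2
    # a box found at (100, 50, 300, 90) on the shrunken image maps back to:
    box = tuple(int(v * r) for v in (100, 50, 300, 90))   # (225, 112, 675, 202)
    print(r, inference_size, box)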
diff --git a/docs/source/api_doc/detect/text/ml1.png b/docs/source/api_doc/detect/text/ml1.png
new file mode 100644
index 00000000000..7fbd587cecf
Binary files /dev/null and b/docs/source/api_doc/detect/text/ml1.png differ
diff --git a/docs/source/api_doc/detect/text/ml2.jpg b/docs/source/api_doc/detect/text/ml2.jpg
new file mode 100644
index 00000000000..4ad18461bd3
Binary files /dev/null and b/docs/source/api_doc/detect/text/ml2.jpg differ
diff --git a/docs/source/api_doc/detect/text_detect_benchmark.plot.py b/docs/source/api_doc/detect/text_detect_benchmark.plot.py
new file mode 100644
index 00000000000..8866f1159c7
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_benchmark.plot.py
@@ -0,0 +1,53 @@
+import random
+
+from benchmark import BaseBenchmark, create_plot_cli
+
+from imgutils.detect import detect_text
+
+
+class TextDetectBenchmark(BaseBenchmark):
+    def __init__(self, model):
+        BaseBenchmark.__init__(self)
+        self.model = model
+
+    def load(self):
+        from imgutils.detect.text import _open_text_detect_model
+        _ = _open_text_detect_model(self.model)
+
+    def unload(self):
+        from imgutils.detect.text import _open_text_detect_model
+        _open_text_detect_model.cache_clear()
+
+    def run(self):
+        image_file = random.choice(self.all_images)
+        _ = detect_text(image_file, model=self.model)
+
+
+if __name__ == '__main__':
+    create_plot_cli(
+        [
+            (
+                'dbnet_resnet18_fpnc_1200e_icdar2015',
+                TextDetectBenchmark('dbnet_resnet18_fpnc_1200e_icdar2015')
+            ),
+            (
+                'dbnet_resnet18_fpnc_1200e_totaltext',
+                TextDetectBenchmark('dbnet_resnet18_fpnc_1200e_totaltext')
+            ),
+            (
+                'dbnet_resnet50-oclip_fpnc_1200e_icdar2015',
+                TextDetectBenchmark('dbnet_resnet50-oclip_fpnc_1200e_icdar2015')
+            ),
+            (
+                'dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015',
+                TextDetectBenchmark('dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015')
+            ),
+            (
+                'dbnetpp_resnet50_fpnc_1200e_icdar2015',
+                TextDetectBenchmark('dbnetpp_resnet50_fpnc_1200e_icdar2015')
+            ),
+        ],
+        title='Benchmark for Text Detections',
+        run_times=10,
+        try_times=20,
+    )()
diff --git a/docs/source/api_doc/detect/text_detect_benchmark.plot.py.svg b/docs/source/api_doc/detect/text_detect_benchmark.plot.py.svg
new file mode 100644
index 00000000000..7bb3cc82577
[2804-line Matplotlib v3.7.3 SVG (2023-10-08): "Benchmark for Text Detections" figure; vector data omitted]
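The benchmark classes share a three-step contract with ``EyeDetectBenchmark`` above: ``load()`` warms the ``lru_cache``-backed model, ``run()`` performs one inference on a random sample, and ``unload()`` calls ``cache_clear()``. Since ``BaseBenchmark`` and ``create_plot_cli`` are docs-local helpers, a rough standalone timing loop under the same contract could look like this (the sample list is a placeholder)::

    import random
    import time

    from imgutils.detect import detect_text
    from imgutils.detect.text import _open_text_detect_model

    images = ['text/ml1.png', 'text/ml2.jpg']  # placeholder sample set
    _open_text_detect_model('dbnetpp_resnet50_fpnc_1200e_icdar2015')  # load once, outside the timed loop
    times = []
    for _ in range(10):
        start = time.perf_counter()
        detect_text(random.choice(images), model='dbnetpp_resnet50_fpnc_1200e_icdar2015')
        times.append(time.perf_counter() - start)
    print(f'mean inference: {sum(times) / len(times):.3f}s')
    _open_text_detect_model.cache_clear()  # unload, mirroring BaseBenchmark.unload()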
diff --git a/docs/source/api_doc/detect/text_detect_demo.plot.py b/docs/source/api_doc/detect/text_detect_demo.plot.py
new file mode 100644
index 00000000000..6be505b8f5c
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_demo.plot.py
@@ -0,0 +1,17 @@
+from plot import image_plot
+
+from imgutils.detect import detect_text
+from imgutils.detect.visual import detection_visualize
+
+
+def _detect(img, **kwargs):
+    return detection_visualize(img, detect_text(img, **kwargs))
+
+
+if __name__ == '__main__':
+    image_plot(
+        (_detect('text/ml1.png'), 'Multiple Languages I'),
+        (_detect('text/ml2.jpg'), 'Multiple Languages II'),
+        columns=1,
+        figsize=(8, 9),
+    )
diff --git a/docs/source/api_doc/detect/text_detect_demo.plot.py.svg b/docs/source/api_doc/detect/text_detect_demo.plot.py.svg
new file mode 100644
index 00000000000..093304b2704
[370-line Matplotlib v3.7.3 SVG (2023-10-08): demo with panels "Multiple Languages I" / "Multiple Languages II"; vector data omitted]
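Both demo scripts route results through ``detection_visualize`` plus the docs-local ``image_plot`` helper; outside the docs build, the same visualization works with matplotlib alone. A minimal sketch, assuming a local image file::

    from matplotlib import pyplot as plt

    from imgutils.detect import detect_text, detection_visualize

    image = 'your_image.jpg'  # hypothetical input path
    plt.imshow(detection_visualize(image, detect_text(image)))
    plt.axis('off')
    plt.savefig('text_detect_vis.png', bbox_inches='tight')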
diff --git a/docs/source/api_doc/tagging/index.rst b/docs/source/api_doc/tagging/index.rst
index 7a67a0e87c7..989a0928135 100644
--- a/docs/source/api_doc/tagging/index.rst
+++ b/docs/source/api_doc/tagging/index.rst
@@ -13,3 +13,4 @@ imgutils.tagging
    wd14
    deepdanbooru
    format
+   overlap
diff --git a/docs/source/api_doc/tagging/overlap.rst b/docs/source/api_doc/tagging/overlap.rst
new file mode 100644
index 00000000000..f62b0a44daf
--- /dev/null
+++ b/docs/source/api_doc/tagging/overlap.rst
@@ -0,0 +1,22 @@
+imgutils.tagging.overlap
+====================================
+
+.. currentmodule:: imgutils.tagging.overlap
+
+.. automodule:: imgutils.tagging.overlap
+
+
+drop_overlap_tags
+----------------------------------
+
+.. autofunction:: drop_overlap_tags
+
+
+
+drop_overlaps_for_dict
+----------------------------------
+
+.. autofunction:: drop_overlaps_for_dict
+
+
+
diff --git a/imgutils/config/meta.py b/imgutils/config/meta.py
index db12b9657c2..2526cbe12c5 100644
--- a/imgutils/config/meta.py
+++ b/imgutils/config/meta.py
@@ -7,7 +7,7 @@
 __TITLE__ = 'imgutils'
 
 #: Version of this project.
-__VERSION__ = '0.2.5'
+__VERSION__ = '0.2.7'
 
 #: Short description of the project, will be included in ``setup.py``.
 __DESCRIPTION__ = 'A convenient and user-friendly anime-style image data processing library that integrates ' \
diff --git a/imgutils/detect/__init__.py b/imgutils/detect/__init__.py
index fe37f41116b..2252e8641c8 100644
--- a/imgutils/detect/__init__.py
+++ b/imgutils/detect/__init__.py
@@ -9,9 +9,11 @@
     :align: center
 """
 from .censor import detect_censors
+from .eye import detect_eyes
 from .face import detect_faces
 from .halfbody import detect_halfbody
 from .hand import detect_hands
 from .head import detect_heads
 from .person import detect_person
+from .text import detect_text
 from .visual import detection_visualize
diff --git a/imgutils/detect/eye.py b/imgutils/detect/eye.py
new file mode 100644
index 00000000000..ace44fc87b0
--- /dev/null
+++ b/imgutils/detect/eye.py
@@ -0,0 +1,76 @@
+"""
+Overview:
+    Detect eyes in anime images.
+
+    Trained on dataset `deepghs/anime_eye_detection <https://huggingface.co/datasets/deepghs/anime_eye_detection>`_ with YOLOv8.
+
+    .. image:: eye_detect_demo.plot.py.svg
+        :align: center
+
+    This is an overall benchmark of all the eye detect models:
+
+    .. image:: eye_detect_benchmark.plot.py.svg
+        :align: center
+
+"""
+from functools import lru_cache
+from typing import List, Tuple
+
+from huggingface_hub import hf_hub_download
+
+from ._yolo import _image_preprocess, _data_postprocess
+from ..data import ImageTyping, load_image, rgb_encode
+from ..utils import open_onnx_model
+
+
+@lru_cache()
+def _open_eye_detect_model(level: str = 's', version: str = 'v1.0'):
+    return open_onnx_model(hf_hub_download(
+        f'deepghs/anime_eye_detection',
+        f'eye_detect_{version}_{level}/model.onnx'
+    ))
+
+
+_LABELS = ["eye"]
+
+
+def detect_eyes(image: ImageTyping, level: str = 's', version: str = 'v1.0', max_infer_size=640,
+                conf_threshold: float = 0.3, iou_threshold: float = 0.3) \
+        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
+    """
+    Overview:
+        Detect human eyes in anime images.
+
+    :param image: Image to detect.
+    :param level: The model level being used can be either `s` or `n`.
+        The `n` model runs faster with smaller system overhead, while the `s` model achieves higher accuracy.
+        The default value is `s`.
+    :param version: Version of model, default is ``v1.0``.
+    :param max_infer_size: The maximum image size used for model inference, if the image size exceeds this limit,
+        the image will be resized and used for inference. The default value is `640` pixels.
+ :param conf_threshold: The confidence threshold, only detection results with confidence scores above + this threshold will be returned. The default value is `0.3`. + :param iou_threshold: The detection area coverage overlap threshold, areas with overlaps above this threshold + will be discarded. The default value is `0.3`. + :return: The detection results list, each item includes the detected area `(x0, y0, x1, y1)`, + the target type (always `eye`) and the target confidence score. + + Examples:: + >>> from imgutils.detect import detect_eyes, detection_visualize + >>> + >>> image = 'squat.jpg' + >>> result = detect_eyes(image) # detect it + >>> result + [((297, 239, 341, 271), 'eye', 0.7760562896728516), ((230, 289, 263, 308), 'eye', 0.7682342529296875)] + >>> + >>> # visualize it + >>> from matplotlib import pyplot as plt + >>> plt.imshow(detection_visualize(image, result)) + >>> plt.show() + """ + image = load_image(image, mode='RGB') + new_image, old_size, new_size = _image_preprocess(image, max_infer_size) + + data = rgb_encode(new_image)[None, ...] + output, = _open_eye_detect_model(level, version).run(['output0'], {'images': data}) + return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, _LABELS) diff --git a/imgutils/detect/text.py b/imgutils/detect/text.py new file mode 100644 index 00000000000..45518478c01 --- /dev/null +++ b/imgutils/detect/text.py @@ -0,0 +1,141 @@ +""" +Overview: + Detect text in images. + + Models are hosted on `deepghs/text_detection `_. + + .. image:: text_detect_demo.plot.py.svg + :align: center + + This is an overall benchmark of all the text detect models: + + .. image:: text_detect_benchmark.plot.py.svg + :align: center + +""" +from functools import lru_cache +from typing import List, Tuple, Optional + +import cv2 +import numpy as np +from huggingface_hub import hf_hub_download + +from ..data import ImageTyping, load_image +from ..utils import open_onnx_model + +_DEFAULT_MODEL = 'dbnetpp_resnet50_fpnc_1200e_icdar2015' + + +@lru_cache() +def _open_text_detect_model(model: str): + """ + Get an ONNX session for the specified DBNET or DBNET++ model. + + This function downloads the ONNX model and opens it using the imgutils library. + + :param model: Model name for DBNET or DBNET++. + :type model: str + :return: ONNX session for the specified model. + """ + return open_onnx_model(hf_hub_download( + 'deepghs/text_detection', + f'{model}/end2end.onnx' + )) + + +def _get_heatmap_of_text(image: ImageTyping, model: str) -> np.ndarray: + """ + Get the heatmap of text regions from the given image using the specified model. + + :param image: Input image. + :type image: ImageTyping + :param model: Model name for DBNET or DBNET++. + :type model: str + :return: Heatmap of text regions. 
+ :rtype: np.ndarray + """ + origin_width, origin_height = width, height = image.size + align = 32 + if width % align != 0: + width += (align - width % align) + if height % align != 0: + height += (align - height % align) + + input_ = np.array(image).transpose((2, 0, 1)).astype(np.float32) / 255.0 + input_ = np.pad(input_[None, ...], ((0, 0), (0, 0), (0, height - origin_height), (0, width - origin_width))) + + def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)): + mean, std = np.asarray(mean), np.asarray(std) + return (data - mean[None, :, None, None]) / std[None, :, None, None] + + ort = _open_text_detect_model(model) + + input_ = _normalize(input_).astype(np.float32) + output_, = ort.run(['output'], {'input': input_}) + heatmap = output_[0] + heatmap = heatmap[:origin_height, :origin_width] + + return heatmap + + +def _get_bounding_box_of_text(image: ImageTyping, model: str, threshold: float) \ + -> List[Tuple[Tuple[int, int, int, int], float]]: + """ + Get bounding boxes of detected text regions from the given image using the specified model and threshold. + + :param image: Input image. + :type image: ImageTyping + :param model: Model name for DBNET or DBNET++. + :type model: str + :param threshold: Confidence threshold for text detection. + :type threshold: float + :return: List of bounding boxes and their scores. + :rtype: List[Tuple[Tuple[int, int, int, int], float]] + """ + heatmap = _get_heatmap_of_text(image, model) + c_rets = cv2.findContours((heatmap * 255.0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = c_rets[0] if len(c_rets) == 2 else c_rets[1] + bboxes = [] + for c in contours: + x, y, w, h = cv2.boundingRect(c) + x0, y0, x1, y1 = x, y, x + w, y + h + score = heatmap[y0:y1, x0:x1].mean().item() + if score >= threshold: + bboxes.append(((x0, y0, x1, y1), score)) + + return bboxes + + +def detect_text(image: ImageTyping, model: str = _DEFAULT_MODEL, threshold: float = 0.05, + max_area_size: Optional[int] = 640): + """ + Detect text regions in the given image using the specified model and threshold. + + :param image: Input image. + :type image: ImageTyping + :param model: Model name for DBNET or DBNET++. + :type model: str + :param threshold: Confidence threshold for text detection. + :type threshold: float + :param max_area_size: Max area size when doing inference. Default is ``640``, which means if + the image's area is over 640x640, it will be resized. When assigned to ``None``, + it means do not resize in any case. + :type max_area_size: Optional[int] + :return: List of detected text bounding boxes, labels, and scores. 
+ :rtype: List[Tuple[Tuple[int, int, int, int], str, float]] + """ + image = load_image(image) + if max_area_size is not None and image.width * image.height >= max_area_size ** 2: + r = ((image.width * image.height) / (max_area_size ** 2)) ** 0.5 + new_width, new_height = int(image.width / r), int(image.height / r) + image = image.resize((new_width, new_height)) + else: + r = 1.0 + + bboxes = [] + for (x0, y0, x1, y1), score in _get_bounding_box_of_text(image, model, threshold): + x0, y0, x1, y1 = int(x0 * r), int(y0 * r), int(x1 * r), int(y1 * r) + bboxes.append(((x0, y0, x1, y1), 'text', score)) + + bboxes = sorted(bboxes, key=lambda x: x[2], reverse=True) + return bboxes diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py index ca1f028d67b..997e7e4f596 100644 --- a/imgutils/tagging/__init__.py +++ b/imgutils/tagging/__init__.py @@ -11,4 +11,5 @@ from .deepdanbooru import get_deepdanbooru_tags from .format import tags_to_text from .mldanbooru import get_mldanbooru_tags +from .overlap import drop_overlap_tags, drop_overlaps_for_dict from .wd14 import get_wd14_tags diff --git a/imgutils/tagging/deepdanbooru.py b/imgutils/tagging/deepdanbooru.py index 5c8c7a69029..afae2e4932b 100644 --- a/imgutils/tagging/deepdanbooru.py +++ b/imgutils/tagging/deepdanbooru.py @@ -16,6 +16,7 @@ from PIL import Image from huggingface_hub import hf_hub_download +from .overlap import drop_overlaps_for_dict from ..data import ImageTyping, load_image from ..utils import open_onnx_model @@ -31,7 +32,7 @@ def _get_deepdanbooru_labels(): general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, tag_real_names, \ - rating_indexes, general_indexes, character_indexes + rating_indexes, general_indexes, character_indexes @lru_cache() @@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray: def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, - general_threshold: float = 0.5, character_threshold: float = 0.5): + general_threshold: float = 0.5, character_threshold: float = 0.5, + drop_overlap: bool = False): """ Overview: Get tags for anime image based on ``deepdanbooru`` model. @@ -73,6 +75,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, The default value of ``False`` indicates the use of the original tag names. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. 
Example: @@ -120,6 +123,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False, general_names = [labels[i] for i in general_indexes] general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) + if drop_overlap: + general_res = drop_overlaps_for_dict(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] diff --git a/imgutils/tagging/mldanbooru.py b/imgutils/tagging/mldanbooru.py index 7b1f28b81a6..de7741472fa 100644 --- a/imgutils/tagging/mldanbooru.py +++ b/imgutils/tagging/mldanbooru.py @@ -11,6 +11,7 @@ from PIL import Image from huggingface_hub import hf_hub_download +from .overlap import drop_overlaps_for_dict from ..data import load_image, ImageTyping from ..utils import open_onnx_model @@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, - threshold: float = 0.7, size: int = 448, keep_ratio: bool = False): + threshold: float = 0.7, size: int = 448, keep_ratio: bool = False, + drop_overlap: bool = False): """ Overview: Tagging image with ML-Danbooru, similar to @@ -72,6 +74,7 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, :param size: Size when passing the resized image into model, default is ``448``. :param keep_ratio: Keep the original ratio between height and width when passing the image into model, default is ``False``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. Example: Here are some images for example @@ -103,4 +106,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False, output = (1 / (1 + np.exp(-native_output))).reshape(-1) tags = _get_mldanbooru_labels(use_real_name) pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0])) - return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold} + + general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold} + if drop_overlap: + general_tags = drop_overlaps_for_dict(general_tags) + return general_tags diff --git a/imgutils/tagging/overlap.py b/imgutils/tagging/overlap.py new file mode 100644 index 00000000000..47bddc7893c --- /dev/null +++ b/imgutils/tagging/overlap.py @@ -0,0 +1,113 @@ +import json +from functools import lru_cache +from typing import Mapping, List + +from huggingface_hub import hf_hub_download + + +@lru_cache() +def _get_overlap_tags() -> Mapping[str, List[str]]: + """ + Retrieve the overlap tag information from the specified Hugging Face Hub repository. + + This function downloads a JSON file containing tag overlap information and parses it into a dictionary. + + :return: A dictionary where keys are tags and values are lists of overlapping tags. + :rtype: Mapping[str, List[str]] + """ + json_file = hf_hub_download( + 'alea31415/tag_filtering', + 'overlap_tags.json', + repo_type='dataset', + ) + with open(json_file, 'r') as file: + data = json.load(file) + + return { + entry['query']: entry['has_overlap'] + for entry in data if 'has_overlap' in entry and entry['has_overlap'] + } + + +def drop_overlap_tags(tags: List[str]) -> List[str]: + """ + Drop overlapping tags from the given list of tags. + + This function removes tags that have overlaps with other tags based on precomputed overlap information. + + :param tags: A list of tags. 
+ :type tags: List[str] + :return: A list of tags without overlaps. + :rtype: List[str] + + Examples:: + >>> from imgutils.tagging import drop_overlap_tags + >>> + >>> tags = [ + ... '1girl', 'solo', + ... 'long_hair', 'very_long_hair', 'red_hair', + ... 'breasts', 'medium_breasts', + ... ] + >>> drop_overlap_tags(tags) + ['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts'] + """ + overlap_tags_dict = _get_overlap_tags() + result_tags = [] + tags_underscore = [tag.replace(' ', '_') for tag in tags] + + for tag, tag_ in zip(tags, tags_underscore): + + to_remove = False + + # Case 1: If the tag is a key and some of the associated values are in tags + if tag_ in overlap_tags_dict: + overlap_values = set(val for val in overlap_tags_dict[tag_]) + if overlap_values.intersection(set(tags_underscore)): + to_remove = True + + # Checking superword condition separately + for tag_another in tags: + if tag in tag_another and tag != tag_another: + to_remove = True + break + + if not to_remove: + result_tags.append(tag) + + return result_tags + + +def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]: + """ + Drop overlapping tags from the given dictionary of tags with confidence scores. + + This function removes tags that have overlaps with other tags based on precomputed overlap information. + + :param tags: A dictionary where keys are tags and values are confidence scores. + :type tags: Mapping[str, float] + :return: A dictionary with non-overlapping tags and their corresponding confidence scores. + :rtype: Mapping[str, float] + + Examples:: + >>> from imgutils.tagging import drop_overlaps_for_dict + >>> + >>> tags = { + ... '1girl': 0.8849405313291128, + ... 'solo': 0.8548297594823425, + ... 'long_hair': 0.03910296474461261, + ... 'very_long_hair': 0.6615180440330748, + ... 'red_hair': 0.21552028866308015, + ... 'breasts': 0.3165260620737027, + ... 'medium_breasts': 0.47744464927382957, + ... } + >>> drop_overlaps_for_dict(tags) + { + '1girl': 0.8849405313291128, + 'solo': 0.8548297594823425, + 'very_long_hair': 0.6615180440330748, + 'red_hair': 0.21552028866308015, + 'medium_breasts': 0.47744464927382957 + } + """ + key_set = set(drop_overlap_tags(list(tags.keys()))) + return {tag: confidence for tag, confidence in tags.items() if tag in key_set} diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index c24f16e1326..63165919d33 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +from .overlap import drop_overlaps_for_dict from ..data import load_image, ImageTyping from ..utils import open_onnx_model @@ -83,7 +84,8 @@ def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", - general_threshold: float = 0.35, character_threshold: float = 0.85): + general_threshold: float = 0.35, character_threshold: float = 0.85, + drop_overlap: bool = False): """ Overview: Tagging image by wd14 v2 model. Similar to @@ -94,6 +96,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. + :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. 
Example: @@ -148,6 +151,8 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_names = [labels[i] for i in general_indexes] general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) + if drop_overlap: + general_res = drop_overlaps_for_dict(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] diff --git a/imgutils/utils/onnxruntime.py b/imgutils/utils/onnxruntime.py index 455ee892afe..22a37a11b9d 100644 --- a/imgutils/utils/onnxruntime.py +++ b/imgutils/utils/onnxruntime.py @@ -63,14 +63,18 @@ def get_onnx_provider(provider: Optional[str] = None): f'but unsupported provider {provider!r} found.') -def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession: +def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession: options = SessionOptions() options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL if provider == "CPUExecutionProvider": options.intra_op_num_threads = os.cpu_count() + providers = [provider] + if use_cpu and "CPUExecutionProvider" not in providers: + providers.append("CPUExecutionProvider") + logging.info(f'Model {ckpt!r} loaded with provider {provider!r}') - return InferenceSession(ckpt, options, [provider]) + return InferenceSession(ckpt, options, providers=providers) def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession: diff --git a/test/detect/test_censor.py b/test/detect/test_censor.py index d711c160ccb..5025ea15e10 100644 --- a/test/detect/test_censor.py +++ b/test/detect/test_censor.py @@ -13,7 +13,7 @@ def _release_model_after_run(): @pytest.mark.unittest -class TestDetectHead: +class TestDetectCensor: def test_detect_censors(self): detections = detect_censors(get_testfile('nude_girl.png')) assert len(detections) == 3 diff --git a/test/detect/test_eye.py b/test/detect/test_eye.py new file mode 100644 index 00000000000..85c88dd7424 --- /dev/null +++ b/test/detect/test_eye.py @@ -0,0 +1,32 @@ +import pytest + +from imgutils.detect.eye import _open_eye_detect_model, detect_eyes +from test.testings import get_testfile + + +@pytest.fixture(scope='module', autouse=True) +def _release_model_after_run(): + try: + yield + finally: + _open_eye_detect_model.cache_clear() + + +@pytest.mark.unittest +class TestDetectEyes: + def test_detect_eye(self): + detections = detect_eyes(get_testfile('nude_girl.png')) + assert len(detections) == 2 + + values = [] + for bbox, label, score in detections: + assert label in {'eye'} + values.append((bbox, int(score * 1000) / 1000)) + + assert values == pytest.approx([ + ((350, 160, 382, 173), 0.788), + ((294, 170, 319, 181), 0.756), + ]) + + def test_detect_eye_none(self): + assert detect_eyes(get_testfile('png_full.png')) == [] diff --git a/test/detect/test_halfbody.py b/test/detect/test_halfbody.py index e10091ebd10..2371b80e0eb 100644 --- a/test/detect/test_halfbody.py +++ b/test/detect/test_halfbody.py @@ -13,7 +13,7 @@ def _release_model_after_run(): @pytest.mark.unittest -class TestDetectHead: +class TestDetectHalfBody: def test_detect_halfbody(self): detections = detect_halfbody(get_testfile('nude_girl.png')) assert len(detections) == 1 diff --git a/test/detect/test_text.py b/test/detect/test_text.py new file mode 100644 index 00000000000..e2092b5ecd1 --- /dev/null +++ b/test/detect/test_text.py @@ -0,0 +1,49 @@ +import pytest + +from imgutils.detect.text import _open_text_detect_model, detect_text +from 
test.testings import get_testfile + + +@pytest.fixture(scope='module', autouse=True) +def _release_model_after_run(): + try: + yield + finally: + _open_text_detect_model.cache_clear() + + +@pytest.mark.unittest +class TestDetectText: + def test_detect_text(self): + detections = detect_text(get_testfile('ml1.png')) + assert len(detections) == 4 + + values = [] + for bbox, label, score in detections: + assert label in {'text'} + values.append((bbox, int(score * 1000) / 1000)) + + assert values == pytest.approx([ + ((866, 45, 959, 69), 0.543), + ((222, 68, 313, 102), 0.543), + ((424, 82, 508, 113), 0.541), + ((691, 101, 776, 129), 0.471) + ]) + + def test_detect_text_without_resize(self): + detections = detect_text(get_testfile('ml2.jpg'), max_area_size=None) + assert len(detections) == 9 + + values = [] + for bbox, label, score in detections: + assert label in {'text'} + values.append((bbox, int(score * 1000) / 1000)) + + assert values == pytest.approx([ + ((360, 218, 474, 250), 0.686), ((119, 218, 203, 240), 0.653), ((392, 47, 466, 76), 0.617), + ((593, 174, 666, 204), 0.616), ((179, 451, 672, 472), 0.591), ((633, 314, 747, 337), 0.59), + ((392, 369, 517, 386), 0.589), ((621, 81, 681, 102), 0.566), ((209, 92, 281, 122), 0.423), + ]) + + def test_detect_text_none(self): + assert detect_text(get_testfile('png_full.png')) == [] diff --git a/test/tagging/test_deepdanbooru.py b/test/tagging/test_deepdanbooru.py index 168448f88f8..eb53be0fc2c 100644 --- a/test/tagging/test_deepdanbooru.py +++ b/test/tagging/test_deepdanbooru.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -15,7 +15,7 @@ def _release_model_after_run(): @pytest.mark.unittest class TestTaggingDeepdanbooru: - def test_get_deepdanbooru_tags(self, _release_model_after_run): + def test_get_deepdanbooru_tags(self): rating, tags, chars = get_deepdanbooru_tags(get_testfile('6124220.jpg')) assert rating['rating:safe'] > 0.9 assert tags['greyscale'] >= 0.8 @@ -27,3 +27,47 @@ def test_get_deepdanbooru_tags(self, _release_model_after_run): assert tags['1girl'] >= 0.85 assert tags['ring'] > 0.8 assert chars['hu_tao_(genshin_impact)'] >= 0.7 + + def test_get_danbooru_tags_sample(self): + rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png')) + assert rating == pytest.approx({ + 'rating:safe': 8.940696716308594e-06, + 'rating:questionable': 0.012878596782684326, + 'rating:explicit': 0.992286205291748, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.9923416376113892, 'armpits': 0.9226008653640747, 'arms_behind_head': 0.5620371699333191, + 'arms_up': 0.7268614172935486, 'bangs': 0.7465004920959473, 'black_border': 0.9081975221633911, + 'blush': 0.9306209683418274, 'breasts': 0.9972158670425415, + 'eyebrows_visible_through_hair': 0.6717097163200378, 'hair_between_eyes': 0.7044132947921753, + 'hair_intakes': 0.6295598745346069, 'horns': 0.9387356042861938, 'letterboxed': 1.0, + 'long_hair': 0.9871174693107605, 'looking_at_viewer': 0.8953969478607178, + 'medium_breasts': 0.90318363904953, 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, + 'nude': 0.9452821016311646, 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120401501655579, + 'pussy': 0.9943721294403076, 'pussy_juice': 0.8238061666488647, 'red_hair': 0.9203640222549438, + 'smile': 0.6659414172172546, 'solo': 0.9483305811882019, 'spread_legs': 0.7633067965507507, + 'stomach': 0.5396291017532349, 'sweat': 0.7880321145057678, 
'thighs': 0.7451953291893005, + 'uncensored': 0.9594683647155762, 'very_long_hair': 0.740706205368042, + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3) + + def test_get_danbooru_tags_drop_overlap(self): + rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True) + assert rating == pytest.approx({ + 'rating:safe': 8.940696716308594e-06, + 'rating:questionable': 0.012878596782684326, + 'rating:explicit': 0.992286205291748, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.9923416376113892, 'armpits': 0.9226007461547852, 'arms_behind_head': 0.5620364546775818, + 'arms_up': 0.7268615365028381, 'bangs': 0.7465004324913025, 'black_border': 0.9081975221633911, + 'blush': 0.9306209683418274, 'eyebrows_visible_through_hair': 0.6717095971107483, + 'hair_between_eyes': 0.7044129967689514, 'hair_intakes': 0.6295579671859741, 'horns': 0.938735842704773, + 'letterboxed': 1.0, 'looking_at_viewer': 0.8953973650932312, 'medium_breasts': 0.9031840562820435, + 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, 'nude': 0.9452821016311646, + 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120403289794922, 'pussy_juice': 0.8238056898117065, + 'red_hair': 0.9203639030456543, 'smile': 0.6659414172172546, 'solo': 0.948330819606781, + 'spread_legs': 0.7633066177368164, 'stomach': 0.5396295189857483, 'sweat': 0.7880324721336365, + 'thighs': 0.745195746421814, 'uncensored': 0.9594683647155762, 'very_long_hair': 0.7407056093215942 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3) diff --git a/test/tagging/test_mldanbooru.py b/test/tagging/test_mldanbooru.py index c1331aa9470..ecfcde61015 100644 --- a/test/tagging/test_mldanbooru.py +++ b/test/tagging/test_mldanbooru.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -22,3 +22,42 @@ def test_get_mldanbooru_tags(self, keep_ratio): tags = get_mldanbooru_tags(get_testfile('6125785.jpg'), keep_ratio=keep_ratio) assert tags['1girl'] >= 0.95 + + def test_get_mldanbooru_tags_sample(self): + tags = get_mldanbooru_tags(get_testfile('nude_girl.png')) + assert tags == pytest.approx({ + '1girl': 0.9999977350234985, 'breasts': 0.999940037727356, 'nipples': 0.999920129776001, + 'solo': 0.9993574023246765, 'pussy': 0.9993218183517456, 'horns': 0.9977452158927917, + 'nude': 0.995971143245697, 'purple_eyes': 0.9957809448242188, 'long_hair': 0.9929291605949402, + 'navel': 0.9814828038215637, 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, + 'pussy_juice': 0.959962785243988, 'blush': 0.9482676386833191, 'uncensored': 0.9446588158607483, + 'looking_at_viewer': 0.9295657873153687, 'red_hair': 0.919776201248169, + 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, 'arms_up': 0.8882529139518738, + 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, 'lying': 0.8653205037117004, + 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059, + 'very_long_hair': 0.8332053422927856, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409, + 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.7933878302574158, + 'thighs': 0.7835342884063721, 'hair_between_eyes': 0.7693091630935669, + 'eyebrows_visible_through_hair': 0.7672545313835144, 'closed_mouth': 0.7638942003250122, + 'breasts_apart': 
0.7527053952217102, 'bed': 0.7515304088592529, 'slit_pupils': 0.7464283108711243, + 'barefoot': 0.7429600954055786, 'bed_sheet': 0.7186222076416016, 'fang': 0.7162102460861206, + 'clitoris': 0.7013473510742188, + }, abs=1e-3) + + def test_get_mldanbooru_tags_no_overlap(self): + tags = get_mldanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True) + assert tags == pytest.approx({ + '1girl': 0.9999977350234985, 'nipples': 0.999920129776001, 'solo': 0.9993574023246765, + 'horns': 0.9977452158927917, 'purple_eyes': 0.9957809448242188, 'navel': 0.9814828038215637, + 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, 'pussy_juice': 0.959962785243988, + 'blush': 0.9482676386833191, 'uncensored': 0.9446586966514587, 'looking_at_viewer': 0.9295657873153687, + 'red_hair': 0.9197760820388794, 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, + 'arms_up': 0.8882529139518738, 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, + 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059, + 'very_long_hair': 0.8332052230834961, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409, + 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.793387770652771, + 'thighs': 0.7835341691970825, 'hair_between_eyes': 0.7693091034889221, + 'eyebrows_visible_through_hair': 0.7672545909881592, 'closed_mouth': 0.7638942003250122, + 'breasts_apart': 0.7527053356170654, 'slit_pupils': 0.7464284300804138, 'barefoot': 0.7429600358009338, + 'bed_sheet': 0.7186222672462463, 'fang': 0.7162103652954102, 'clitoris': 0.7013473510742188 + }, abs=1e-3) diff --git a/test/tagging/test_overlap.py b/test/tagging/test_overlap.py new file mode 100644 index 00000000000..6535fc7e840 --- /dev/null +++ b/test/tagging/test_overlap.py @@ -0,0 +1,42 @@ +import pytest + +from imgutils.tagging import drop_overlaps_for_dict, drop_overlap_tags + + +@pytest.fixture() +def complex_dict_tags(): + return { + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, + 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, + 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, + 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, + 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, + 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, + 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, + 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, + 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901, + } + + +@pytest.mark.unittest +class TestTaggingOverlap: + def test_drop_overlap_tags(self): + assert drop_overlap_tags(['1girl', 'solo', 'long_hair', 'very_long_hair', 'red_hair']) == \ + ['1girl', 'solo', 'very_long_hair', 'red_hair'] + + def test_drop_overlaps_for_dict_complex(self, complex_dict_tags): + assert drop_overlaps_for_dict(complex_dict_tags) == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 
0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, + 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, + 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, + 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, + 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, + 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, + 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, + 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }) diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index 3318795a64b..897cf9e8729 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -5,7 +5,7 @@ from test.testings import get_testfile -@pytest.fixture() +@pytest.fixture(autouse=True, scope='module') def _release_model_after_run(): try: yield @@ -26,3 +26,51 @@ def test_get_wd14_tags(self): assert 0.35 <= rating['sensitive'] <= 0.45 assert tags['1girl'] >= 0.95 assert chars['hu_tao_(genshin_impact)'] >= 0.95 + + def test_wd14_tags_sample(self): + rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png')) + assert rating == pytest.approx({ + 'general': 0.0020540356636047363, + 'sensitive': 0.0080718994140625, + 'questionable': 0.003170192241668701, + 'explicit': 0.984081506729126, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, + 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, + 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, + 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, + 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, + 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, + 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, + 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, + 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) + + def test_wd14_tags_no_overlap(self): + rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) + # print(tags) + assert rating == pytest.approx({ + 'general': 0.0020540356636047363, + 'sensitive': 0.0080718994140625, + 'questionable': 0.003170192241668701, + 'explicit': 0.984081506729126, + }, abs=1e-3) + assert tags == pytest.approx({ + '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, + 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 
0.49712443351745605, + 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, + 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, + 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, + 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, + 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, + 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, + 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, + 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, + 'clitoris': 0.5310801267623901 + }, abs=1e-3) + assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) diff --git a/test/testfile/ml1.png b/test/testfile/ml1.png new file mode 100644 index 00000000000..7fbd587cecf Binary files /dev/null and b/test/testfile/ml1.png differ diff --git a/test/testfile/ml2.jpg b/test/testfile/ml2.jpg new file mode 100644 index 00000000000..4ad18461bd3 Binary files /dev/null and b/test/testfile/ml2.jpg differ
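One note on the tagger changes above: ``drop_overlap=True`` simply routes the general-tag dict through ``drop_overlaps_for_dict`` before returning, so the flag and an explicit post-filter select the same tag keys. A quick sketch against the suite's test image (the path is a placeholder)::

    from imgutils.tagging import drop_overlaps_for_dict, get_wd14_tags

    rating, tags, chars = get_wd14_tags('nude_girl.png')  # placeholder path
    _, tags_dropped, _ = get_wd14_tags('nude_girl.png', drop_overlap=True)
    # the flag is equivalent to post-filtering the general-tag dict:
    assert set(drop_overlaps_for_dict(tags)) == set(tags_dropped)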