diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml
index 01f7fc824cc..8c563d64bb5 100644
--- a/.github/workflows/doc.yml
+++ b/.github/workflows/doc.yml
@@ -50,8 +50,6 @@ jobs:
if: ${{ github.event_name == 'push' }}
env:
CI: 'true'
- HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
- HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
@@ -122,8 +120,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
- HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
- HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
diff --git a/.github/workflows/export.yml b/.github/workflows/export.yml
index c99de51d347..021c9af9173 100644
--- a/.github/workflows/export.yml
+++ b/.github/workflows/export.yml
@@ -44,8 +44,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
- HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
- HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 547e8740a45..eee99a49181 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -103,8 +103,6 @@ jobs:
uses: nick-fields/retry@v2
env:
CI: 'true'
- HF_NARUGO_USERNAME: ${{ secrets.HF_NARUGO_USERNAME }}
- HF_NARUGO_PASSWORD: ${{ secrets.HF_NARUGO_PASSWORD }}
with:
shell: bash
timeout_minutes: 20
diff --git a/Makefile b/Makefile
index 9b2aa145493..a8fab6ade5c 100644
--- a/Makefile
+++ b/Makefile
@@ -50,10 +50,10 @@ pdocs:
dataset:
mkdir -p ${DATASET_DIR}
if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \
- git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
+ git clone https://huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
fi
if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \
- git clone https://${HF_NARUGO_USERNAME}:${HF_NARUGO_PASSWORD}@huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
+ git clone https://huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
fi
if [ ! -d ${DATASET_DIR}/images_test_v1 ]; then \
mkdir -p ${DATASET_DIR}/images_test_v1 && \
diff --git a/docs/source/_libs/plot.py b/docs/source/_libs/plot.py
index 87bcd71ebc3..8afe3ff135f 100644
--- a/docs/source/_libs/plot.py
+++ b/docs/source/_libs/plot.py
@@ -1,6 +1,7 @@
from typing import Tuple, List
import matplotlib.pyplot as plt
+import numpy as np
from PIL import Image
from cli import _wrap_func_as_cli
@@ -44,6 +45,14 @@ def image_plot(*images, save_as: str, columns=2, keep_axis: bool = False, figsiz
n = len(images)
rows = (n + columns - 1) // columns
fig, axs = plt.subplots(rows, columns, figsize=figsize)
+ if rows == 1 and columns == 1:
+ axs = np.array([[axs]])
+ elif rows == 1:
+ axs = axs[None, ...]
+ elif columns == 1:
+ axs = axs[..., None]
+ else:
+ pass
plt.subplots_adjust(wspace=0.2, hspace=0.15)
for i, img in enumerate(images, start=0):
xi, yi = i // columns, i % columns
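
The added branch normalizes the shape that `plt.subplots` returns, which depends on the grid size; a minimal illustrative sketch (not part of the patch) of the three cases being unified:

    import matplotlib.pyplot as plt

    # With the default squeeze=True, plt.subplots returns a bare Axes for a
    # 1x1 grid, a 1-D ndarray for a single row or column, and a 2-D ndarray
    # otherwise; promoting every case to 2-D lets the plotting loop always
    # index axs with a (row, column) pair.
    _, ax_single = plt.subplots(1, 1)   # matplotlib Axes
    _, ax_row = plt.subplots(1, 3)      # ndarray of shape (3,)
    _, ax_grid = plt.subplots(2, 3)     # ndarray of shape (2, 3)
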
diff --git a/docs/source/api_doc/detect/eye.rst b/docs/source/api_doc/detect/eye.rst
new file mode 100644
index 00000000000..98f39be8362
--- /dev/null
+++ b/docs/source/api_doc/detect/eye.rst
@@ -0,0 +1,14 @@
+imgutils.detect.eye
+==========================
+
+.. currentmodule:: imgutils.detect.eye
+
+.. automodule:: imgutils.detect.eye
+
+
+detect_eyes
+------------------------------
+
+.. autofunction:: detect_eyes
+
+
diff --git a/docs/source/api_doc/detect/eye_detect_benchmark.plot.py b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py
new file mode 100644
index 00000000000..f4c31266ad7
--- /dev/null
+++ b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py
@@ -0,0 +1,37 @@
+import random
+
+from benchmark import BaseBenchmark, create_plot_cli
+from imgutils.detect import detect_eyes
+
+
+class EyeDetectBenchmark(BaseBenchmark):
+ def __init__(self, level, version):
+ BaseBenchmark.__init__(self)
+ self.level = level
+ self.version = version
+
+ def load(self):
+ from imgutils.detect.eye import _open_eye_detect_model
+ _ = _open_eye_detect_model(level=self.level, version=self.version)
+
+ def unload(self):
+ from imgutils.detect.eye import _open_eye_detect_model
+ _open_eye_detect_model.cache_clear()
+
+ def run(self):
+ image_file = random.choice(self.all_images)
+ _ = detect_eyes(image_file, level=self.level, version=self.version)
+
+
+if __name__ == '__main__':
+ create_plot_cli(
+ [
+ ('eye v1.0 (yolov8s)', EyeDetectBenchmark('s', 'v1.0')),
+ ('eye v1.0 (yolov8n)', EyeDetectBenchmark('n', 'v1.0')),
+ ('eye v0.8 (yolov8s)', EyeDetectBenchmark('s', 'v0.8')),
+ ('eye v0.7 (yolov8s)', EyeDetectBenchmark('s', 'v0.7')),
+ ],
+ title='Benchmark for Anime Eyes Detections',
+ run_times=10,
+ try_times=20,
+ )()
diff --git a/docs/source/api_doc/detect/eye_detect_benchmark.plot.py.svg b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py.svg
new file mode 100644
index 00000000000..1679499c798
--- /dev/null
+++ b/docs/source/api_doc/detect/eye_detect_benchmark.plot.py.svg
@@ -0,0 +1,2546 @@
+
+
+
diff --git a/docs/source/api_doc/detect/eye_detect_demo.plot.py b/docs/source/api_doc/detect/eye_detect_demo.plot.py
new file mode 100644
index 00000000000..675eac261cf
--- /dev/null
+++ b/docs/source/api_doc/detect/eye_detect_demo.plot.py
@@ -0,0 +1,19 @@
+from imgutils.detect import detect_eyes
+from imgutils.detect.eye import _LABELS
+from imgutils.detect.visual import detection_visualize
+from plot import image_plot
+
+
+def _detect(img, **kwargs):
+ return detection_visualize(img, detect_eyes(img, **kwargs), _LABELS)
+
+
+if __name__ == '__main__':
+ image_plot(
+ (_detect('nian.png'), 'large scale'),
+ (_detect('two_bikini_girls.png'), 'closed heads'),
+ (_detect('halfbody/squat.jpg'), 'pose'),
+ (_detect('mostima_post.jpg'), 'multiple'),
+ columns=2,
+ figsize=(10, 9),
+ )
diff --git a/docs/source/api_doc/detect/eye_detect_demo.plot.py.svg b/docs/source/api_doc/detect/eye_detect_demo.plot.py.svg
new file mode 100644
index 00000000000..67ec6b3cf6e
--- /dev/null
+++ b/docs/source/api_doc/detect/eye_detect_demo.plot.py.svg
@@ -0,0 +1,481 @@
+
+
+
diff --git a/docs/source/api_doc/detect/index.rst b/docs/source/api_doc/detect/index.rst
index 4e8c53e77c5..48a1360957d 100644
--- a/docs/source/api_doc/detect/index.rst
+++ b/docs/source/api_doc/detect/index.rst
@@ -10,10 +10,12 @@ imgutils.detect
:maxdepth: 3
censor
+ eye
face
halfbody
hand
head
person
+ text
visual
diff --git a/docs/source/api_doc/detect/text.rst b/docs/source/api_doc/detect/text.rst
new file mode 100644
index 00000000000..dd05bf7e97d
--- /dev/null
+++ b/docs/source/api_doc/detect/text.rst
@@ -0,0 +1,14 @@
+imgutils.detect.text
+==========================
+
+.. currentmodule:: imgutils.detect.text
+
+.. automodule:: imgutils.detect.text
+
+
+detect_text
+------------------------------
+
+.. autofunction:: detect_text
+
+
diff --git a/docs/source/api_doc/detect/text/ml1.png b/docs/source/api_doc/detect/text/ml1.png
new file mode 100644
index 00000000000..7fbd587cecf
Binary files /dev/null and b/docs/source/api_doc/detect/text/ml1.png differ
diff --git a/docs/source/api_doc/detect/text/ml2.jpg b/docs/source/api_doc/detect/text/ml2.jpg
new file mode 100644
index 00000000000..4ad18461bd3
Binary files /dev/null and b/docs/source/api_doc/detect/text/ml2.jpg differ
diff --git a/docs/source/api_doc/detect/text_detect_benchmark.plot.py b/docs/source/api_doc/detect/text_detect_benchmark.plot.py
new file mode 100644
index 00000000000..8866f1159c7
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_benchmark.plot.py
@@ -0,0 +1,53 @@
+import random
+
+from benchmark import BaseBenchmark, create_plot_cli
+
+from imgutils.detect import detect_text
+
+
+class TextDetectBenchmark(BaseBenchmark):
+ def __init__(self, model):
+ BaseBenchmark.__init__(self)
+ self.model = model
+
+ def load(self):
+ from imgutils.detect.text import _open_text_detect_model
+ _ = _open_text_detect_model(self.model)
+
+ def unload(self):
+ from imgutils.detect.text import _open_text_detect_model
+ _open_text_detect_model.cache_clear()
+
+ def run(self):
+ image_file = random.choice(self.all_images)
+ _ = detect_text(image_file, model=self.model)
+
+
+if __name__ == '__main__':
+ create_plot_cli(
+ [
+ (
+ 'dbnet_resnet18_fpnc_1200e_icdar2015',
+ TextDetectBenchmark('dbnet_resnet18_fpnc_1200e_icdar2015')
+ ),
+ (
+ 'dbnet_resnet18_fpnc_1200e_totaltext',
+ TextDetectBenchmark('dbnet_resnet18_fpnc_1200e_totaltext')
+ ),
+ (
+ 'dbnet_resnet50-oclip_fpnc_1200e_icdar2015',
+ TextDetectBenchmark('dbnet_resnet50-oclip_fpnc_1200e_icdar2015')
+ ),
+ (
+ 'dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015',
+ TextDetectBenchmark('dbnetpp_resnet50-oclip_fpnc_1200e_icdar2015')
+ ),
+ (
+ 'dbnetpp_resnet50_fpnc_1200e_icdar2015',
+ TextDetectBenchmark('dbnetpp_resnet50_fpnc_1200e_icdar2015')
+ ),
+ ],
+ title='Benchmark for Text Detections',
+ run_times=10,
+ try_times=20,
+ )()
diff --git a/docs/source/api_doc/detect/text_detect_benchmark.plot.py.svg b/docs/source/api_doc/detect/text_detect_benchmark.plot.py.svg
new file mode 100644
index 00000000000..7bb3cc82577
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_benchmark.plot.py.svg
@@ -0,0 +1,2804 @@
+
+
+
diff --git a/docs/source/api_doc/detect/text_detect_demo.plot.py b/docs/source/api_doc/detect/text_detect_demo.plot.py
new file mode 100644
index 00000000000..6be505b8f5c
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_demo.plot.py
@@ -0,0 +1,17 @@
+from plot import image_plot
+
+from imgutils.detect import detect_text
+from imgutils.detect.visual import detection_visualize
+
+
+def _detect(img, **kwargs):
+ return detection_visualize(img, detect_text(img, **kwargs))
+
+
+if __name__ == '__main__':
+ image_plot(
+ (_detect('text/ml1.png'), 'Multiple Languages I'),
+ (_detect('text/ml2.jpg'), 'Multiple Languages II'),
+ columns=1,
+ figsize=(8, 9),
+ )
diff --git a/docs/source/api_doc/detect/text_detect_demo.plot.py.svg b/docs/source/api_doc/detect/text_detect_demo.plot.py.svg
new file mode 100644
index 00000000000..093304b2704
--- /dev/null
+++ b/docs/source/api_doc/detect/text_detect_demo.plot.py.svg
@@ -0,0 +1,370 @@
+
+
+
diff --git a/docs/source/api_doc/tagging/index.rst b/docs/source/api_doc/tagging/index.rst
index 7a67a0e87c7..989a0928135 100644
--- a/docs/source/api_doc/tagging/index.rst
+++ b/docs/source/api_doc/tagging/index.rst
@@ -13,3 +13,4 @@ imgutils.tagging
wd14
deepdanbooru
format
+ overlap
diff --git a/docs/source/api_doc/tagging/overlap.rst b/docs/source/api_doc/tagging/overlap.rst
new file mode 100644
index 00000000000..f62b0a44daf
--- /dev/null
+++ b/docs/source/api_doc/tagging/overlap.rst
@@ -0,0 +1,22 @@
+imgutils.tagging.overlap
+====================================
+
+.. currentmodule:: imgutils.tagging.overlap
+
+.. automodule:: imgutils.tagging.overlap
+
+
+drop_overlap_tags
+----------------------------------
+
+.. autofunction:: drop_overlap_tags
+
+
+
+drop_overlaps_for_dict
+----------------------------------
+
+.. autofunction:: drop_overlaps_for_dict
+
+
+
diff --git a/imgutils/config/meta.py b/imgutils/config/meta.py
index db12b9657c2..2526cbe12c5 100644
--- a/imgutils/config/meta.py
+++ b/imgutils/config/meta.py
@@ -7,7 +7,7 @@
__TITLE__ = 'imgutils'
#: Version of this project.
-__VERSION__ = '0.2.5'
+__VERSION__ = '0.2.7'
#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A convenient and user-friendly anime-style image data processing library that integrates ' \
diff --git a/imgutils/detect/__init__.py b/imgutils/detect/__init__.py
index fe37f41116b..2252e8641c8 100644
--- a/imgutils/detect/__init__.py
+++ b/imgutils/detect/__init__.py
@@ -9,9 +9,11 @@
:align: center
"""
from .censor import detect_censors
+from .eye import detect_eyes
from .face import detect_faces
from .halfbody import detect_halfbody
from .hand import detect_hands
from .head import detect_heads
from .person import detect_person
+from .text import detect_text
from .visual import detection_visualize
diff --git a/imgutils/detect/eye.py b/imgutils/detect/eye.py
new file mode 100644
index 00000000000..ace44fc87b0
--- /dev/null
+++ b/imgutils/detect/eye.py
@@ -0,0 +1,76 @@
+"""
+Overview:
+ Detect eyes in anime images.
+
+ Trained on dataset `deepghs/anime_eye_detection <https://huggingface.co/datasets/deepghs/anime_eye_detection>`_ with YOLOv8.
+
+ .. image:: eye_detect_demo.plot.py.svg
+ :align: center
+
+ This is an overall benchmark of all the eye detect models:
+
+ .. image:: eye_detect_benchmark.plot.py.svg
+ :align: center
+
+"""
+from functools import lru_cache
+from typing import List, Tuple
+
+from huggingface_hub import hf_hub_download
+
+from ._yolo import _image_preprocess, _data_postprocess
+from ..data import ImageTyping, load_image, rgb_encode
+from ..utils import open_onnx_model
+
+
+@lru_cache()
+def _open_eye_detect_model(level: str = 's', version: str = 'v1.0'):
+ return open_onnx_model(hf_hub_download(
+ 'deepghs/anime_eye_detection',
+ f'eye_detect_{version}_{level}/model.onnx'
+ ))
+
+
+_LABELS = ["eye"]
+
+
+def detect_eyes(image: ImageTyping, level: str = 's', version: str = 'v1.0', max_infer_size=640,
+ conf_threshold: float = 0.3, iou_threshold: float = 0.3) \
+ -> List[Tuple[Tuple[int, int, int, int], str, float]]:
+ """
+ Overview:
+ Detect human eyes in anime images.
+
+ :param image: Image to detect.
+ :param level: The model level to use, which can be either `s` or `n`.
+ The `n` model runs faster with lower system overhead, while the `s` model achieves higher accuracy.
+ The default value is `s`.
+ :param version: Version of model, default is ``v1.0``.
+ :param max_infer_size: The maximum image size used for model inference, if the image size exceeds this limit,
+ the image will be resized and used for inference. The default value is `640` pixels.
+ :param conf_threshold: The confidence threshold; only detection results with confidence scores above
+ this threshold will be returned. The default value is `0.3`.
+ :param iou_threshold: The detection area coverage overlap threshold; areas with overlaps above this threshold
+ will be discarded. The default value is `0.3`.
+ :return: The detection results list, each item includes the detected area `(x0, y0, x1, y1)`,
+ the target type (always `eye`) and the target confidence score.
+
+ Examples::
+ >>> from imgutils.detect import detect_eyes, detection_visualize
+ >>>
+ >>> image = 'squat.jpg'
+ >>> result = detect_eyes(image) # detect it
+ >>> result
+ [((297, 239, 341, 271), 'eye', 0.7760562896728516), ((230, 289, 263, 308), 'eye', 0.7682342529296875)]
+ >>>
+ >>> # visualize it
+ >>> from matplotlib import pyplot as plt
+ >>> plt.imshow(detection_visualize(image, result))
+ >>> plt.show()
+ """
+ image = load_image(image, mode='RGB')
+ new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
+
+ data = rgb_encode(new_image)[None, ...]
+ output, = _open_eye_detect_model(level, version).run(['output0'], {'images': data})
+ return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, _LABELS)
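
Beyond the doctest above, here is a minimal sketch of cropping the regions returned by the new detect_eyes function (the input file name is hypothetical):

    from PIL import Image

    from imgutils.detect import detect_eyes

    # 'illustration.png' is a hypothetical local file
    image = Image.open('illustration.png')
    for i, ((x0, y0, x1, y1), label, score) in enumerate(detect_eyes(image)):
        # each detection is ((x0, y0, x1, y1), 'eye', confidence)
        image.crop((x0, y0, x1, y1)).save(f'eye_{i}_{score:.3f}.png')
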
diff --git a/imgutils/detect/text.py b/imgutils/detect/text.py
new file mode 100644
index 00000000000..45518478c01
--- /dev/null
+++ b/imgutils/detect/text.py
@@ -0,0 +1,141 @@
+"""
+Overview:
+ Detect text in images.
+
+ Models are hosted on `deepghs/text_detection <https://huggingface.co/deepghs/text_detection>`_.
+
+ .. image:: text_detect_demo.plot.py.svg
+ :align: center
+
+ This is an overall benchmark of all the text detect models:
+
+ .. image:: text_detect_benchmark.plot.py.svg
+ :align: center
+
+"""
+from functools import lru_cache
+from typing import List, Tuple, Optional
+
+import cv2
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+from ..data import ImageTyping, load_image
+from ..utils import open_onnx_model
+
+_DEFAULT_MODEL = 'dbnetpp_resnet50_fpnc_1200e_icdar2015'
+
+
+@lru_cache()
+def _open_text_detect_model(model: str):
+ """
+ Get an ONNX session for the specified DBNET or DBNET++ model.
+
+ This function downloads the ONNX model and opens it using the imgutils library.
+
+ :param model: Model name for DBNET or DBNET++.
+ :type model: str
+ :return: ONNX session for the specified model.
+ """
+ return open_onnx_model(hf_hub_download(
+ 'deepghs/text_detection',
+ f'{model}/end2end.onnx'
+ ))
+
+
+def _get_heatmap_of_text(image: ImageTyping, model: str) -> np.ndarray:
+ """
+ Get the heatmap of text regions from the given image using the specified model.
+
+ :param image: Input image.
+ :type image: ImageTyping
+ :param model: Model name for DBNET or DBNET++.
+ :type model: str
+ :return: Heatmap of text regions.
+ :rtype: np.ndarray
+ """
+ origin_width, origin_height = width, height = image.size
+ align = 32
+ if width % align != 0:
+ width += (align - width % align)
+ if height % align != 0:
+ height += (align - height % align)
+
+ input_ = np.array(image).transpose((2, 0, 1)).astype(np.float32) / 255.0
+ input_ = np.pad(input_[None, ...], ((0, 0), (0, 0), (0, height - origin_height), (0, width - origin_width)))
+
+ def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)):
+ mean, std = np.asarray(mean), np.asarray(std)
+ return (data - mean[None, :, None, None]) / std[None, :, None, None]
+
+ ort = _open_text_detect_model(model)
+
+ input_ = _normalize(input_).astype(np.float32)
+ output_, = ort.run(['output'], {'input': input_})
+ heatmap = output_[0]
+ heatmap = heatmap[:origin_height, :origin_width]
+
+ return heatmap
+
+
+def _get_bounding_box_of_text(image: ImageTyping, model: str, threshold: float) \
+ -> List[Tuple[Tuple[int, int, int, int], float]]:
+ """
+ Get bounding boxes of detected text regions from the given image using the specified model and threshold.
+
+ :param image: Input image.
+ :type image: ImageTyping
+ :param model: Model name for DBNET or DBNET++.
+ :type model: str
+ :param threshold: Confidence threshold for text detection.
+ :type threshold: float
+ :return: List of bounding boxes and their scores.
+ :rtype: List[Tuple[Tuple[int, int, int, int], float]]
+ """
+ heatmap = _get_heatmap_of_text(image, model)
+ c_rets = cv2.findContours((heatmap * 255.0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ contours = c_rets[0] if len(c_rets) == 2 else c_rets[1]
+ bboxes = []
+ for c in contours:
+ x, y, w, h = cv2.boundingRect(c)
+ x0, y0, x1, y1 = x, y, x + w, y + h
+ score = heatmap[y0:y1, x0:x1].mean().item()
+ if score >= threshold:
+ bboxes.append(((x0, y0, x1, y1), score))
+
+ return bboxes
+
+
+def detect_text(image: ImageTyping, model: str = _DEFAULT_MODEL, threshold: float = 0.05,
+ max_area_size: Optional[int] = 640):
+ """
+ Detect text regions in the given image using the specified model and threshold.
+
+ :param image: Input image.
+ :type image: ImageTyping
+ :param model: Model name for DBNET or DBNET++.
+ :type model: str
+ :param threshold: Confidence threshold for text detection.
+ :type threshold: float
+ :param max_area_size: Max area size used for inference. Default is ``640``, which means
+ the image will be resized if its area exceeds 640x640. If set to ``None``,
+ the image will never be resized.
+ :type max_area_size: Optional[int]
+ :return: List of detected text bounding boxes, labels, and scores.
+ :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
+ """
+ image = load_image(image)
+ if max_area_size is not None and image.width * image.height >= max_area_size ** 2:
+ r = ((image.width * image.height) / (max_area_size ** 2)) ** 0.5
+ new_width, new_height = int(image.width / r), int(image.height / r)
+ image = image.resize((new_width, new_height))
+ else:
+ r = 1.0
+
+ bboxes = []
+ for (x0, y0, x1, y1), score in _get_bounding_box_of_text(image, model, threshold):
+ x0, y0, x1, y1 = int(x0 * r), int(y0 * r), int(x1 * r), int(y1 * r)
+ bboxes.append(((x0, y0, x1, y1), 'text', score))
+
+ bboxes = sorted(bboxes, key=lambda x: x[2], reverse=True)
+ return bboxes
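
detect_text follows the same ((x0, y0, x1, y1), label, score) convention as the other detectors but has no doctest yet; a minimal usage sketch (the input file name is hypothetical):

    from matplotlib import pyplot as plt

    from imgutils.detect import detect_text
    from imgutils.detect.visual import detection_visualize

    image_file = 'comic_page.jpg'  # hypothetical input file
    result = detect_text(image_file)
    for (x0, y0, x1, y1), label, score in result:
        # label is always 'text'; results are sorted by score, highest first
        print(f'({x0}, {y0}, {x1}, {y1}) -> {score:.3f}')

    plt.imshow(detection_visualize(image_file, result))
    plt.show()
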
diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py
index ca1f028d67b..997e7e4f596 100644
--- a/imgutils/tagging/__init__.py
+++ b/imgutils/tagging/__init__.py
@@ -11,4 +11,5 @@
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .mldanbooru import get_mldanbooru_tags
+from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .wd14 import get_wd14_tags
diff --git a/imgutils/tagging/deepdanbooru.py b/imgutils/tagging/deepdanbooru.py
index 5c8c7a69029..afae2e4932b 100644
--- a/imgutils/tagging/deepdanbooru.py
+++ b/imgutils/tagging/deepdanbooru.py
@@ -16,6 +16,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download
+from .overlap import drop_overlaps_for_dict
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model
@@ -31,7 +32,7 @@ def _get_deepdanbooru_labels():
general_indexes = list(np.where(df["category"] == 0)[0])
character_indexes = list(np.where(df["category"] == 4)[0])
return tag_names, tag_real_names, \
- rating_indexes, general_indexes, character_indexes
+ rating_indexes, general_indexes, character_indexes
@lru_cache()
@@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray:
def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
- general_threshold: float = 0.5, character_threshold: float = 0.5):
+ general_threshold: float = 0.5, character_threshold: float = 0.5,
+ drop_overlap: bool = False):
"""
Overview:
Get tags for anime image based on ``deepdanbooru`` model.
@@ -73,6 +75,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
The default value of ``False`` indicates the use of the original tag names.
:param general_threshold: Threshold for default tags, default is ``0.35``.
:param character_threshold: Threshold for character tags, default is ``0.85``.
+ :param drop_overlap: Drop overlap tags or not, default is ``False``.
:return: Tagging results for levels, features and characters.
Example:
@@ -120,6 +123,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
+ if drop_overlap:
+ general_res = drop_overlaps_for_dict(general_res)
# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
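
The new drop_overlap flag (added in the same way to get_mldanbooru_tags and get_wd14_tags below) passes the general-tag dictionary through drop_overlaps_for_dict before it is returned; a minimal sketch of the call (the input file name is hypothetical):

    from imgutils.tagging import get_deepdanbooru_tags

    # 'illustration.png' is a hypothetical input file
    rating, general, characters = get_deepdanbooru_tags('illustration.png', drop_overlap=True)
    # broader tags such as 'long_hair' or 'breasts' are removed when a more
    # specific tag ('very_long_hair', 'medium_breasts', ...) is also present
    print(general)
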
diff --git a/imgutils/tagging/mldanbooru.py b/imgutils/tagging/mldanbooru.py
index 7b1f28b81a6..de7741472fa 100644
--- a/imgutils/tagging/mldanbooru.py
+++ b/imgutils/tagging/mldanbooru.py
@@ -11,6 +11,7 @@
from PIL import Image
from huggingface_hub import hf_hub_download
+from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
@@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List
def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
- threshold: float = 0.7, size: int = 448, keep_ratio: bool = False):
+ threshold: float = 0.7, size: int = 448, keep_ratio: bool = False,
+ drop_overlap: bool = False):
"""
Overview:
Tagging image with ML-Danbooru, similar to
@@ -72,6 +74,7 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
:param size: Size when passing the resized image into model, default is ``448``.
:param keep_ratio: Keep the original ratio between height and width when passing the image into
model, default is ``False``.
+ :param drop_overlap: Drop overlap tags or not, default is ``False``.
Example:
Here are some images for example
@@ -103,4 +106,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
output = (1 / (1 + np.exp(-native_output))).reshape(-1)
tags = _get_mldanbooru_labels(use_real_name)
pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0]))
- return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
+
+ general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
+ if drop_overlap:
+ general_tags = drop_overlaps_for_dict(general_tags)
+ return general_tags
diff --git a/imgutils/tagging/overlap.py b/imgutils/tagging/overlap.py
new file mode 100644
index 00000000000..47bddc7893c
--- /dev/null
+++ b/imgutils/tagging/overlap.py
@@ -0,0 +1,113 @@
+import json
+from functools import lru_cache
+from typing import Mapping, List
+
+from huggingface_hub import hf_hub_download
+
+
+@lru_cache()
+def _get_overlap_tags() -> Mapping[str, List[str]]:
+ """
+ Retrieve the overlap tag information from the specified Hugging Face Hub repository.
+
+ This function downloads a JSON file containing tag overlap information and parses it into a dictionary.
+
+ :return: A dictionary where keys are tags and values are lists of overlapping tags.
+ :rtype: Mapping[str, List[str]]
+ """
+ json_file = hf_hub_download(
+ 'alea31415/tag_filtering',
+ 'overlap_tags.json',
+ repo_type='dataset',
+ )
+ with open(json_file, 'r') as file:
+ data = json.load(file)
+
+ return {
+ entry['query']: entry['has_overlap']
+ for entry in data if 'has_overlap' in entry and entry['has_overlap']
+ }
+
+
+def drop_overlap_tags(tags: List[str]) -> List[str]:
+ """
+ Drop overlapping tags from the given list of tags.
+
+ This function removes tags that have overlaps with other tags based on precomputed overlap information.
+
+ :param tags: A list of tags.
+ :type tags: List[str]
+ :return: A list of tags without overlaps.
+ :rtype: List[str]
+
+ Examples::
+ >>> from imgutils.tagging import drop_overlap_tags
+ >>>
+ >>> tags = [
+ ... '1girl', 'solo',
+ ... 'long_hair', 'very_long_hair', 'red_hair',
+ ... 'breasts', 'medium_breasts',
+ ... ]
+ >>> drop_overlap_tags(tags)
+ ['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts']
+ """
+ overlap_tags_dict = _get_overlap_tags()
+ result_tags = []
+ tags_underscore = [tag.replace(' ', '_') for tag in tags]
+
+ for tag, tag_ in zip(tags, tags_underscore):
+
+ to_remove = False
+
+ # Case 1: If the tag is a key and some of the associated values are in tags
+ if tag_ in overlap_tags_dict:
+ overlap_values = set(overlap_tags_dict[tag_])
+ if overlap_values.intersection(set(tags_underscore)):
+ to_remove = True
+
+ # Checking superword condition separately
+ for tag_another in tags:
+ if tag in tag_another and tag != tag_another:
+ to_remove = True
+ break
+
+ if not to_remove:
+ result_tags.append(tag)
+
+ return result_tags
+
+
+def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]:
+ """
+ Drop overlapping tags from the given dictionary of tags with confidence scores.
+
+ This function removes tags that have overlaps with other tags based on precomputed overlap information.
+
+ :param tags: A dictionary where keys are tags and values are confidence scores.
+ :type tags: Mapping[str, float]
+ :return: A dictionary with non-overlapping tags and their corresponding confidence scores.
+ :rtype: Mapping[str, float]
+
+ Examples::
+ >>> from imgutils.tagging import drop_overlaps_for_dict
+ >>>
+ >>> tags = {
+ ... '1girl': 0.8849405313291128,
+ ... 'solo': 0.8548297594823425,
+ ... 'long_hair': 0.03910296474461261,
+ ... 'very_long_hair': 0.6615180440330748,
+ ... 'red_hair': 0.21552028866308015,
+ ... 'breasts': 0.3165260620737027,
+ ... 'medium_breasts': 0.47744464927382957,
+ ... }
+ >>> drop_overlaps_for_dict(tags)
+ {
+ '1girl': 0.8849405313291128,
+ 'solo': 0.8548297594823425,
+ 'very_long_hair': 0.6615180440330748,
+ 'red_hair': 0.21552028866308015,
+ 'medium_breasts': 0.47744464927382957
+ }
+ """
+ key_set = set(drop_overlap_tags(list(tags.keys())))
+ return {tag: confidence for tag, confidence in tags.items() if tag in key_set}
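
drop_overlap_tags matches tags against the downloaded overlap table in underscore form but preserves the original spelling in its output; a small sketch (the tag lists are illustrative, and the expected outputs follow the behaviour shown in the docstring examples):

    from imgutils.tagging import drop_overlap_tags, drop_overlaps_for_dict

    # the lookup is done in underscore form, so space-separated tags work too
    # and are returned unchanged; expected: ['1girl', 'very long hair']
    print(drop_overlap_tags(['1girl', 'long hair', 'very long hair']))

    # the dict variant keeps the confidence scores of the surviving keys;
    # expected: {'very_long_hair': 0.8, 'red_hair': 0.7}
    print(drop_overlaps_for_dict({'long_hair': 0.9, 'very_long_hair': 0.8, 'red_hair': 0.7}))
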
diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py
index c24f16e1326..63165919d33 100644
--- a/imgutils/tagging/wd14.py
+++ b/imgutils/tagging/wd14.py
@@ -11,6 +11,7 @@
import numpy as np
import pandas as pd
+from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
@@ -83,7 +84,8 @@ def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]:
def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
- general_threshold: float = 0.35, character_threshold: float = 0.85):
+ general_threshold: float = 0.35, character_threshold: float = 0.85,
+ drop_overlap: bool = False):
"""
Overview:
Tagging image by wd14 v2 model. Similar to
@@ -94,6 +96,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``.
:param general_threshold: Threshold for default tags, default is ``0.35``.
:param character_threshold: Threshold for character tags, default is ``0.85``.
+ :param drop_overlap: Drop overlap tags or not, default is ``False``.
:return: Tagging results for levels, features and characters.
Example:
@@ -148,6 +151,8 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
general_names = [labels[i] for i in general_indexes]
general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
+ if drop_overlap:
+ general_res = drop_overlaps_for_dict(general_res)
# Everything else is characters: pick anywhere prediction confidence > threshold
character_names = [labels[i] for i in character_indexes]
diff --git a/imgutils/utils/onnxruntime.py b/imgutils/utils/onnxruntime.py
index 455ee892afe..22a37a11b9d 100644
--- a/imgutils/utils/onnxruntime.py
+++ b/imgutils/utils/onnxruntime.py
@@ -63,14 +63,18 @@ def get_onnx_provider(provider: Optional[str] = None):
f'but unsupported provider {provider!r} found.')
-def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession:
+def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession:
options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
if provider == "CPUExecutionProvider":
options.intra_op_num_threads = os.cpu_count()
+ providers = [provider]
+ if use_cpu and "CPUExecutionProvider" not in providers:
+ providers.append("CPUExecutionProvider")
+
logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
- return InferenceSession(ckpt, options, [provider])
+ return InferenceSession(ckpt, options, providers=providers)
def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession:
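
Always appending CPUExecutionProvider gives InferenceSession a fallback when the requested accelerator is unavailable; a minimal sketch of the resulting call (the model path is hypothetical):

    from onnxruntime import InferenceSession, SessionOptions

    # 'model.onnx' is a hypothetical model file; onnxruntime applies the listed
    # providers in priority order, so CPUExecutionProvider acts as a fallback
    # when the preferred accelerator is not usable on the current machine
    session = InferenceSession(
        'model.onnx',
        SessionOptions(),
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )
    print(session.get_providers())
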
diff --git a/test/detect/test_censor.py b/test/detect/test_censor.py
index d711c160ccb..5025ea15e10 100644
--- a/test/detect/test_censor.py
+++ b/test/detect/test_censor.py
@@ -13,7 +13,7 @@ def _release_model_after_run():
@pytest.mark.unittest
-class TestDetectHead:
+class TestDetectCensor:
def test_detect_censors(self):
detections = detect_censors(get_testfile('nude_girl.png'))
assert len(detections) == 3
diff --git a/test/detect/test_eye.py b/test/detect/test_eye.py
new file mode 100644
index 00000000000..85c88dd7424
--- /dev/null
+++ b/test/detect/test_eye.py
@@ -0,0 +1,32 @@
+import pytest
+
+from imgutils.detect.eye import _open_eye_detect_model, detect_eyes
+from test.testings import get_testfile
+
+
+@pytest.fixture(scope='module', autouse=True)
+def _release_model_after_run():
+ try:
+ yield
+ finally:
+ _open_eye_detect_model.cache_clear()
+
+
+@pytest.mark.unittest
+class TestDetectEyes:
+ def test_detect_eye(self):
+ detections = detect_eyes(get_testfile('nude_girl.png'))
+ assert len(detections) == 2
+
+ values = []
+ for bbox, label, score in detections:
+ assert label in {'eye'}
+ values.append((bbox, int(score * 1000) / 1000))
+
+ assert values == pytest.approx([
+ ((350, 160, 382, 173), 0.788),
+ ((294, 170, 319, 181), 0.756),
+ ])
+
+ def test_detect_eye_none(self):
+ assert detect_eyes(get_testfile('png_full.png')) == []
diff --git a/test/detect/test_halfbody.py b/test/detect/test_halfbody.py
index e10091ebd10..2371b80e0eb 100644
--- a/test/detect/test_halfbody.py
+++ b/test/detect/test_halfbody.py
@@ -13,7 +13,7 @@ def _release_model_after_run():
@pytest.mark.unittest
-class TestDetectHead:
+class TestDetectHalfBody:
def test_detect_halfbody(self):
detections = detect_halfbody(get_testfile('nude_girl.png'))
assert len(detections) == 1
diff --git a/test/detect/test_text.py b/test/detect/test_text.py
new file mode 100644
index 00000000000..e2092b5ecd1
--- /dev/null
+++ b/test/detect/test_text.py
@@ -0,0 +1,49 @@
+import pytest
+
+from imgutils.detect.text import _open_text_detect_model, detect_text
+from test.testings import get_testfile
+
+
+@pytest.fixture(scope='module', autouse=True)
+def _release_model_after_run():
+ try:
+ yield
+ finally:
+ _open_text_detect_model.cache_clear()
+
+
+@pytest.mark.unittest
+class TestDetectText:
+ def test_detect_text(self):
+ detections = detect_text(get_testfile('ml1.png'))
+ assert len(detections) == 4
+
+ values = []
+ for bbox, label, score in detections:
+ assert label in {'text'}
+ values.append((bbox, int(score * 1000) / 1000))
+
+ assert values == pytest.approx([
+ ((866, 45, 959, 69), 0.543),
+ ((222, 68, 313, 102), 0.543),
+ ((424, 82, 508, 113), 0.541),
+ ((691, 101, 776, 129), 0.471)
+ ])
+
+ def test_detect_text_without_resize(self):
+ detections = detect_text(get_testfile('ml2.jpg'), max_area_size=None)
+ assert len(detections) == 9
+
+ values = []
+ for bbox, label, score in detections:
+ assert label in {'text'}
+ values.append((bbox, int(score * 1000) / 1000))
+
+ assert values == pytest.approx([
+ ((360, 218, 474, 250), 0.686), ((119, 218, 203, 240), 0.653), ((392, 47, 466, 76), 0.617),
+ ((593, 174, 666, 204), 0.616), ((179, 451, 672, 472), 0.591), ((633, 314, 747, 337), 0.59),
+ ((392, 369, 517, 386), 0.589), ((621, 81, 681, 102), 0.566), ((209, 92, 281, 122), 0.423),
+ ])
+
+ def test_detect_text_none(self):
+ assert detect_text(get_testfile('png_full.png')) == []
diff --git a/test/tagging/test_deepdanbooru.py b/test/tagging/test_deepdanbooru.py
index 168448f88f8..eb53be0fc2c 100644
--- a/test/tagging/test_deepdanbooru.py
+++ b/test/tagging/test_deepdanbooru.py
@@ -5,7 +5,7 @@
from test.testings import get_testfile
-@pytest.fixture()
+@pytest.fixture(autouse=True, scope='module')
def _release_model_after_run():
try:
yield
@@ -15,7 +15,7 @@ def _release_model_after_run():
@pytest.mark.unittest
class TestTaggingDeepdanbooru:
- def test_get_deepdanbooru_tags(self, _release_model_after_run):
+ def test_get_deepdanbooru_tags(self):
rating, tags, chars = get_deepdanbooru_tags(get_testfile('6124220.jpg'))
assert rating['rating:safe'] > 0.9
assert tags['greyscale'] >= 0.8
@@ -27,3 +27,47 @@ def test_get_deepdanbooru_tags(self, _release_model_after_run):
assert tags['1girl'] >= 0.85
assert tags['ring'] > 0.8
assert chars['hu_tao_(genshin_impact)'] >= 0.7
+
+ def test_get_danbooru_tags_sample(self):
+ rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'))
+ assert rating == pytest.approx({
+ 'rating:safe': 8.940696716308594e-06,
+ 'rating:questionable': 0.012878596782684326,
+ 'rating:explicit': 0.992286205291748,
+ }, abs=1e-3)
+ assert tags == pytest.approx({
+ '1girl': 0.9923416376113892, 'armpits': 0.9226008653640747, 'arms_behind_head': 0.5620371699333191,
+ 'arms_up': 0.7268614172935486, 'bangs': 0.7465004920959473, 'black_border': 0.9081975221633911,
+ 'blush': 0.9306209683418274, 'breasts': 0.9972158670425415,
+ 'eyebrows_visible_through_hair': 0.6717097163200378, 'hair_between_eyes': 0.7044132947921753,
+ 'hair_intakes': 0.6295598745346069, 'horns': 0.9387356042861938, 'letterboxed': 1.0,
+ 'long_hair': 0.9871174693107605, 'looking_at_viewer': 0.8953969478607178,
+ 'medium_breasts': 0.90318363904953, 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044,
+ 'nude': 0.9452821016311646, 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120401501655579,
+ 'pussy': 0.9943721294403076, 'pussy_juice': 0.8238061666488647, 'red_hair': 0.9203640222549438,
+ 'smile': 0.6659414172172546, 'solo': 0.9483305811882019, 'spread_legs': 0.7633067965507507,
+ 'stomach': 0.5396291017532349, 'sweat': 0.7880321145057678, 'thighs': 0.7451953291893005,
+ 'uncensored': 0.9594683647155762, 'very_long_hair': 0.740706205368042,
+ }, abs=1e-3)
+ assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)
+
+ def test_get_danbooru_tags_drop_overlap(self):
+ rating, tags, chars = get_deepdanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True)
+ assert rating == pytest.approx({
+ 'rating:safe': 8.940696716308594e-06,
+ 'rating:questionable': 0.012878596782684326,
+ 'rating:explicit': 0.992286205291748,
+ }, abs=1e-3)
+ assert tags == pytest.approx({
+ '1girl': 0.9923416376113892, 'armpits': 0.9226007461547852, 'arms_behind_head': 0.5620364546775818,
+ 'arms_up': 0.7268615365028381, 'bangs': 0.7465004324913025, 'black_border': 0.9081975221633911,
+ 'blush': 0.9306209683418274, 'eyebrows_visible_through_hair': 0.6717095971107483,
+ 'hair_between_eyes': 0.7044129967689514, 'hair_intakes': 0.6295579671859741, 'horns': 0.938735842704773,
+ 'letterboxed': 1.0, 'looking_at_viewer': 0.8953973650932312, 'medium_breasts': 0.9031840562820435,
+ 'navel': 0.9425054788589478, 'nipples': 0.9989081621170044, 'nude': 0.9452821016311646,
+ 'pillarboxed': 0.9854832887649536, 'purple_eyes': 0.8120403289794922, 'pussy_juice': 0.8238056898117065,
+ 'red_hair': 0.9203639030456543, 'smile': 0.6659414172172546, 'solo': 0.948330819606781,
+ 'spread_legs': 0.7633066177368164, 'stomach': 0.5396295189857483, 'sweat': 0.7880324721336365,
+ 'thighs': 0.745195746421814, 'uncensored': 0.9594683647155762, 'very_long_hair': 0.7407056093215942
+ }, abs=1e-3)
+ assert chars == pytest.approx({'surtr_(arknights)': 0.9373699426651001}, abs=1e-3)
diff --git a/test/tagging/test_mldanbooru.py b/test/tagging/test_mldanbooru.py
index c1331aa9470..ecfcde61015 100644
--- a/test/tagging/test_mldanbooru.py
+++ b/test/tagging/test_mldanbooru.py
@@ -5,7 +5,7 @@
from test.testings import get_testfile
-@pytest.fixture()
+@pytest.fixture(autouse=True, scope='module')
def _release_model_after_run():
try:
yield
@@ -22,3 +22,42 @@ def test_get_mldanbooru_tags(self, keep_ratio):
tags = get_mldanbooru_tags(get_testfile('6125785.jpg'), keep_ratio=keep_ratio)
assert tags['1girl'] >= 0.95
+
+ def test_get_mldanbooru_tags_sample(self):
+ tags = get_mldanbooru_tags(get_testfile('nude_girl.png'))
+ assert tags == pytest.approx({
+ '1girl': 0.9999977350234985, 'breasts': 0.999940037727356, 'nipples': 0.999920129776001,
+ 'solo': 0.9993574023246765, 'pussy': 0.9993218183517456, 'horns': 0.9977452158927917,
+ 'nude': 0.995971143245697, 'purple_eyes': 0.9957809448242188, 'long_hair': 0.9929291605949402,
+ 'navel': 0.9814828038215637, 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068,
+ 'pussy_juice': 0.959962785243988, 'blush': 0.9482676386833191, 'uncensored': 0.9446588158607483,
+ 'looking_at_viewer': 0.9295657873153687, 'red_hair': 0.919776201248169,
+ 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917, 'arms_up': 0.8882529139518738,
+ 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079, 'lying': 0.8653205037117004,
+ 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059,
+ 'very_long_hair': 0.8332053422927856, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409,
+ 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.7933878302574158,
+ 'thighs': 0.7835342884063721, 'hair_between_eyes': 0.7693091630935669,
+ 'eyebrows_visible_through_hair': 0.7672545313835144, 'closed_mouth': 0.7638942003250122,
+ 'breasts_apart': 0.7527053952217102, 'bed': 0.7515304088592529, 'slit_pupils': 0.7464283108711243,
+ 'barefoot': 0.7429600954055786, 'bed_sheet': 0.7186222076416016, 'fang': 0.7162102460861206,
+ 'clitoris': 0.7013473510742188,
+ }, abs=1e-3)
+
+ def test_get_mldanbooru_tags_no_overlap(self):
+ tags = get_mldanbooru_tags(get_testfile('nude_girl.png'), drop_overlap=True)
+ assert tags == pytest.approx({
+ '1girl': 0.9999977350234985, 'nipples': 0.999920129776001, 'solo': 0.9993574023246765,
+ 'horns': 0.9977452158927917, 'purple_eyes': 0.9957809448242188, 'navel': 0.9814828038215637,
+ 'armpits': 0.9808009266853333, 'spread_legs': 0.9767358303070068, 'pussy_juice': 0.959962785243988,
+ 'blush': 0.9482676386833191, 'uncensored': 0.9446586966514587, 'looking_at_viewer': 0.9295657873153687,
+ 'red_hair': 0.9197760820388794, 'medium_breasts': 0.9020175337791443, 'completely_nude': 0.8965569138526917,
+ 'arms_up': 0.8882529139518738, 'on_back': 0.8701885342597961, 'arms_behind_head': 0.8692260980606079,
+ 'pillow': 0.8645844459533691, 'bangs': 0.8618668913841248, 'smile': 0.8531544804573059,
+ 'very_long_hair': 0.8332052230834961, 'pointy_ears': 0.8194612264633179, 'stomach': 0.8194073438644409,
+ 'hair_intakes': 0.8191318511962891, 'on_bed': 0.8055890202522278, 'sweat': 0.793387770652771,
+ 'thighs': 0.7835341691970825, 'hair_between_eyes': 0.7693091034889221,
+ 'eyebrows_visible_through_hair': 0.7672545909881592, 'closed_mouth': 0.7638942003250122,
+ 'breasts_apart': 0.7527053356170654, 'slit_pupils': 0.7464284300804138, 'barefoot': 0.7429600358009338,
+ 'bed_sheet': 0.7186222672462463, 'fang': 0.7162103652954102, 'clitoris': 0.7013473510742188
+ }, abs=1e-3)
diff --git a/test/tagging/test_overlap.py b/test/tagging/test_overlap.py
new file mode 100644
index 00000000000..6535fc7e840
--- /dev/null
+++ b/test/tagging/test_overlap.py
@@ -0,0 +1,42 @@
+import pytest
+
+from imgutils.tagging import drop_overlaps_for_dict, drop_overlap_tags
+
+
+@pytest.fixture()
+def complex_dict_tags():
+ return {
+ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507,
+ 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438,
+ 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896,
+ 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697,
+ 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648,
+ 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317,
+ 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223,
+ 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056,
+ 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686,
+ 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
+ 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901,
+ }
+
+
+@pytest.mark.unittest
+class TestTaggingOverlap:
+ def test_drop_overlap_tags(self):
+ assert drop_overlap_tags(['1girl', 'solo', 'long_hair', 'very_long_hair', 'red_hair']) == \
+ ['1girl', 'solo', 'very_long_hair', 'red_hair']
+
+ def test_drop_overlaps_for_dict_complex(self, complex_dict_tags):
+ assert drop_overlaps_for_dict(complex_dict_tags) == pytest.approx({
+ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698,
+ 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605,
+ 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948,
+ 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633,
+ 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317,
+ 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658,
+ 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423,
+ 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345,
+ 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
+ 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
+ 'clitoris': 0.5310801267623901
+ })
diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py
index 3318795a64b..897cf9e8729 100644
--- a/test/tagging/test_wd14.py
+++ b/test/tagging/test_wd14.py
@@ -5,7 +5,7 @@
from test.testings import get_testfile
-@pytest.fixture()
+@pytest.fixture(autouse=True, scope='module')
def _release_model_after_run():
try:
yield
@@ -26,3 +26,51 @@ def test_get_wd14_tags(self):
assert 0.35 <= rating['sensitive'] <= 0.45
assert tags['1girl'] >= 0.95
assert chars['hu_tao_(genshin_impact)'] >= 0.95
+
+ def test_wd14_tags_sample(self):
+ rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'))
+ assert rating == pytest.approx({
+ 'general': 0.0020540356636047363,
+ 'sensitive': 0.0080718994140625,
+ 'questionable': 0.003170192241668701,
+ 'explicit': 0.984081506729126,
+ }, abs=1e-3)
+ assert tags == pytest.approx({
+ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507,
+ 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438,
+ 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896,
+ 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697,
+ 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648,
+ 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317,
+ 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223,
+ 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056,
+ 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686,
+ 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
+ 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
+ 'clitoris': 0.5310801267623901
+ }, abs=1e-3)
+ assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3)
+
+ def test_wd14_tags_no_overlap(self):
+ rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True)
+ # print(tags)
+ assert rating == pytest.approx({
+ 'general': 0.0020540356636047363,
+ 'sensitive': 0.0080718994140625,
+ 'questionable': 0.003170192241668701,
+ 'explicit': 0.984081506729126,
+ }, abs=1e-3)
+ assert tags == pytest.approx({
+ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698,
+ 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605,
+ 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948,
+ 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633,
+ 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317,
+ 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658,
+ 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423,
+ 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345,
+ 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
+ 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
+ 'clitoris': 0.5310801267623901
+ }, abs=1e-3)
+ assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3)
diff --git a/test/testfile/ml1.png b/test/testfile/ml1.png
new file mode 100644
index 00000000000..7fbd587cecf
Binary files /dev/null and b/test/testfile/ml1.png differ
diff --git a/test/testfile/ml2.jpg b/test/testfile/ml2.jpg
new file mode 100644
index 00000000000..4ad18461bd3
Binary files /dev/null and b/test/testfile/ml2.jpg differ