Skip to content

Commit

Permalink
Merge pull request #129 from deepghs/dev/yolo
Browse files Browse the repository at this point in the history
dev(narugo): use static input size for yolo models
  • Loading branch information
narugo1992 authored Dec 9, 2024
2 parents 7df7bf5 + 9f7019b commit 4967a9b
Show file tree
Hide file tree
Showing 23 changed files with 137 additions and 131 deletions.
3 changes: 2 additions & 1 deletion imgutils/detect/booru_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@


def detect_with_booru_yolo(image: ImageTyping, model_name: str = _DEFAULT_MODEL,
conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
conf_threshold: float = 0.25, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Perform object detection on an image using the Booru YOLO model.
Expand Down Expand Up @@ -209,4 +209,5 @@ def detect_with_booru_yolo(image: ImageTyping, model_name: str = _DEFAULT_MODEL,
model_name=model_name,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/censor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


def detect_censors(image: ImageTyping, level: str = 's', version: str = 'v1.0', model_name: Optional[str] = None,
conf_threshold: float = 0.3, iou_threshold: float = 0.7) \
conf_threshold: float = 0.3, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human censor points in anime images.
Expand Down Expand Up @@ -88,4 +88,5 @@ def detect_censors(image: ImageTyping, level: str = 's', version: str = 'v1.0',
model_name=model_name or f'censor_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/eye.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def detect_eyes(image: ImageTyping, level: str = 's', version: str = 'v1.0', model_name: Optional[str] = None,
conf_threshold: float = 0.3, iou_threshold: float = 0.3) \
conf_threshold: float = 0.3, iou_threshold: float = 0.3, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human eyes in anime images.
Expand Down Expand Up @@ -79,4 +79,5 @@ def detect_eyes(image: ImageTyping, level: str = 's', version: str = 'v1.0', mod
model_name=model_name or f'eye_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def detect_faces(image: ImageTyping, level: str = 's', version: str = 'v1.4', model_name: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
conf_threshold: float = 0.25, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human faces in anime images using YOLOv8 models.
Expand Down Expand Up @@ -87,4 +87,5 @@ def detect_faces(image: ImageTyping, level: str = 's', version: str = 'v1.4', mo
model_name=model_name or f'face_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/halfbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def detect_halfbody(image: ImageTyping, level: str = 's', version: str = 'v1.0', model_name: Optional[str] = None,
conf_threshold: float = 0.5, iou_threshold: float = 0.7) \
conf_threshold: float = 0.5, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human upper-half body in anime images.
Expand Down Expand Up @@ -87,4 +87,5 @@ def detect_halfbody(image: ImageTyping, level: str = 's', version: str = 'v1.0',
model_name=model_name or f'halfbody_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/hand.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def detect_hands(image: ImageTyping, level: str = 's', version: str = 'v1.0', model_name: Optional[str] = None,
conf_threshold: float = 0.35, iou_threshold: float = 0.7) \
conf_threshold: float = 0.35, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human hand points in anime images.
Expand Down Expand Up @@ -76,4 +76,5 @@ def detect_hands(image: ImageTyping, level: str = 's', version: str = 'v1.0', mo
model_name=model_name or f'hand_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

def detect_heads(image: ImageTyping, level: Optional[str] = None,
model_name: Optional[str] = 'head_detect_v2.0_s',
conf_threshold: float = 0.4, iou_threshold: float = 0.7) \
conf_threshold: float = 0.4, iou_threshold: float = 0.7, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect human heads in anime images using YOLOv8 models.
Expand Down Expand Up @@ -96,4 +96,5 @@ def detect_heads(image: ImageTyping, level: Optional[str] = None,
model_name=model_name or f'head_detect_v0_{level or "s"}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
3 changes: 2 additions & 1 deletion imgutils/detect/person.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


def detect_person(image: ImageTyping, level: str = 'm', version: str = 'v1.1', model_name: Optional[str] = None,
conf_threshold: float = 0.3, iou_threshold: float = 0.5):
conf_threshold: float = 0.3, iou_threshold: float = 0.5, **kwargs):
"""
Detect human bodies (including the entire body) in anime images.
Expand Down Expand Up @@ -83,4 +83,5 @@ def detect_person(image: ImageTyping, level: str = 'm', version: str = 'v1.1', m
model_name=model_name or f'person_detect_{version}_{level}',
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
49 changes: 36 additions & 13 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import math
import os
from threading import Lock
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -155,7 +155,8 @@ def _yolo_nms(boxes, scores, iou_threshold: float = 0.7) -> List[int]:
return keep


def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int = 32):
def _image_preprocess(image: Image.Image, max_infer_size: Union[int, Tuple[int, int]] = 1216,
allow_dynamic: bool = False, align: int = 32):
"""
Preprocess an input image for inference.
Expand All @@ -166,6 +167,8 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
:type image: Image.Image
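:param max_infer_size: Maximum inference size of the processed image, either a single int applied to both width and height, or a ``(width, height)`` tuple. Default is 1216.
:type max_infer_size: Union[int, Tuple[int, int]]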
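:param allow_dynamic: If True, the image is only shrunk (keeping its aspect ratio) to fit within ``max_infer_size``, and its sides are then rounded up to a multiple of ``align``; if False, the image is resized to exactly ``max_infer_size``. Default is False.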
:type allow_dynamic: bool
:param align: Value to align the image dimensions to. Default is 32.
:type align: int
Expand All @@ -183,13 +186,22 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
>>> print(old_size, new_size)
(1000, 800) (1216, 992)
"""
if isinstance(max_infer_size, int):
max_infer_width, max_infer_height = max_infer_size, max_infer_size
else:
max_infer_width, max_infer_height = max_infer_size

old_width, old_height = image.width, image.height
new_width, new_height = old_width, old_height
r = max_infer_size / max(new_width, new_height)
if r < 1:
new_width, new_height = new_width * r, new_height * r
new_width = int(math.ceil(new_width / align) * align)
new_height = int(math.ceil(new_height / align) * align)
if allow_dynamic:
r = min(max_infer_width / new_width, max_infer_height / new_height)
if r < 1:
new_width, new_height = new_width * r, new_height * r
new_width = int(math.ceil(new_width / align) * align)
new_height = int(math.ceil(new_height / align) * align)
else:
new_width, new_height = max_infer_width, max_infer_height

image = image.resize((new_width, new_height))
return image, (old_width, old_height), (new_width, new_height)

Expand Down Expand Up @@ -539,7 +551,8 @@ def _open_model(self, model_name: str):
))
model_metadata = model.get_modelmeta()
if 'imgsz' in model_metadata.custom_metadata_map:
max_infer_size = max(json.loads(model_metadata.custom_metadata_map['imgsz']))
max_infer_size = tuple(json.loads(model_metadata.custom_metadata_map['imgsz']))
assert len(max_infer_size) == 2, f'imgsz should have 2 dims, but {max_infer_size!r} found.'
else:
max_infer_size = 640
names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
Expand Down Expand Up @@ -567,7 +580,8 @@ def _get_model_type(self, model_name: str):
return self._model_types[model_name]

def predict(self, image: ImageTyping, model_name: str,
conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
conf_threshold: float = 0.25, iou_threshold: float = 0.7,
allow_dynamic: bool = False) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Perform object detection on an image using the specified YOLO model.
Expand All @@ -580,6 +594,9 @@ def predict(self, image: ImageTyping, model_name: str,
:type conf_threshold: float
:param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
:type iou_threshold: float
:param allow_dynamic: If True, the image is shrunk (keeping its aspect ratio) to fit within the model's inference size; if False, it is resized to exactly that size.
Default is False.
:type allow_dynamic: bool
:return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
:rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
Expand All @@ -594,7 +611,7 @@ def predict(self, image: ImageTyping, model_name: str,
"""
model, max_infer_size, labels = self._open_model(model_name)
image = load_image(image, mode='RGB')
new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
data = rgb_encode(new_image)[None, ...]
output, = model.run(['output0'], {'images': data})
model_type = self._get_model_type(model_name=model_name)
Expand Down Expand Up @@ -669,7 +686,8 @@ def make_ui(self, default_model_name: Optional[str] = None,
default_model_name = selected_model_name

def _gr_detect(image: ImageTyping, model_name: str,
iou_threshold: float = 0.7, score_threshold: float = 0.25) \
iou_threshold: float = 0.7, score_threshold: float = 0.25,
allow_dynamic: bool = False) \
-> gr.AnnotatedImage:
_, _, labels = self._open_model(model_name=model_name)
_colors = list(map(str, rnd_colors(len(labels))))
Expand All @@ -682,6 +700,7 @@ def _gr_detect(image: ImageTyping, model_name: str,
model_name=model_name,
iou_threshold=iou_threshold,
conf_threshold=score_threshold,
allow_dynamic=allow_dynamic,
)
]),
color_map=_color_map,
Expand All @@ -691,7 +710,9 @@ def _gr_detect(image: ImageTyping, model_name: str,
with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
with gr.Row():
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
with gr.Row():
gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
Expand All @@ -708,6 +729,7 @@ def _gr_detect(image: ImageTyping, model_name: str,
gr_model,
gr_iou_threshold,
gr_score_threshold,
gr_allow_dynamic,
],
outputs=[gr_output_image],
)
Expand Down Expand Up @@ -791,7 +813,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YO

def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
conf_threshold: float = 0.25, iou_threshold: float = 0.7,
hf_token: Optional[str] = None) \
hf_token: Optional[str] = None, **kwargs) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Perform object detection on an image using a YOLO model from a Hugging Face repository.
Expand Down Expand Up @@ -828,4 +850,5 @@ def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
model_name=model_name,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
**kwargs,
)
24 changes: 8 additions & 16 deletions test/detect/test_booru_yolo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from PIL import Image

from imgutils.detect import detect_with_booru_yolo
from imgutils.detect import detect_with_booru_yolo, detection_similarity
from imgutils.generic.yolo import _open_models_for_repo_id
from ..testings import get_testfile

Expand All @@ -27,28 +27,20 @@ def nude_girl_image(nude_girl_file):
@pytest.fixture()
def nude_girl_detection():
return [
((236, 1, 452, 247), 'head', 0.9584360718727112),
((211, 236, 431, 346), 'boob', 0.9300149083137512),
((62, 402, 427, 697), 'sprd', 0.8708215951919556)
((243, 0, 444, 253), 'head', 0.9584344029426575),
((213, 231, 426, 358), 'boob', 0.9308794140815735),
((86, 393, 401, 701), 'sprd', 0.8639463186264038)
]


@pytest.mark.unittest
class TestDetectBooruYOLO:
def test_detect_with_booru_yolo_file(self, nude_girl_file, nude_girl_detection):
detection = detect_with_booru_yolo(nude_girl_file)
assert [label for _, label, _ in detection] == \
[label for _, label, _ in nude_girl_detection]
for (actual_box, _, _), (expected_box, _, _) in zip(detection, nude_girl_detection):
assert actual_box == pytest.approx(expected_box)
assert [score for _, _, score in detection] == \
pytest.approx([score for _, _, score in nude_girl_detection], abs=1e-4)
similarity = detection_similarity(detection, nude_girl_detection)
assert similarity >= 0.9

def test_detect_with_booru_yolo_image(self, nude_girl_image, nude_girl_detection):
detection = detect_with_booru_yolo(nude_girl_image)
assert [label for _, label, _ in detection] == \
[label for _, label, _ in nude_girl_detection]
for (actual_box, _, _), (expected_box, _, _) in zip(detection, nude_girl_detection):
assert actual_box == pytest.approx(expected_box)
assert [score for _, _, score in detection] == \
pytest.approx([score for _, _, score in nude_girl_detection], abs=1e-4)
similarity = detection_similarity(detection, nude_girl_detection)
assert similarity >= 0.9
19 changes: 7 additions & 12 deletions test/detect/test_censor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from imgutils.detect import detection_similarity
from imgutils.detect.censor import detect_censors
from imgutils.generic.yolo import _open_models_for_repo_id
from test.testings import get_testfile
Expand All @@ -16,19 +17,13 @@ def _release_model_after_run():
@pytest.mark.unittest
class TestDetectCensor:
def test_detect_censors(self):
detections = detect_censors(get_testfile('nude_girl.png'))
assert len(detections) == 3

values = []
for bbox, label, score in detections:
assert label in {'nipple_f', 'penis', 'pussy'}
values.append((bbox, int(score * 1000) / 1000))

assert values == pytest.approx([
((365, 264, 399, 289), 0.747),
((224, 260, 252, 285), 0.683),
((206, 523, 240, 608), 0.679),
detection = detect_censors(get_testfile('nude_girl.png'))
similarity = detection_similarity(detection, [
((365, 264, 398, 289), 'nipple_f', 0.7295440435409546),
((207, 525, 237, 610), 'pussy', 0.7148708701133728),
((224, 261, 250, 287), 'nipple_f', 0.6702285408973694),
])
assert similarity >= 0.9

def test_detect_censors_none(self):
assert detect_censors(get_testfile('png_full.png')) == []
17 changes: 6 additions & 11 deletions test/detect/test_eye.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from imgutils.detect import detection_similarity
from imgutils.detect.eye import detect_eyes
from imgutils.generic.yolo import _open_models_for_repo_id
from test.testings import get_testfile
Expand All @@ -16,18 +17,12 @@ def _release_model_after_run():
@pytest.mark.unittest
class TestDetectEyes:
def test_detect_eye(self):
detections = detect_eyes(get_testfile('nude_girl.png'))
assert len(detections) == 2

values = []
for bbox, label, score in detections:
assert label in {'eye'}
values.append((bbox, int(score * 1000) / 1000))

assert values == pytest.approx([
((350, 160, 382, 173), 0.788),
((294, 170, 319, 181), 0.756),
detection = detect_eyes(get_testfile('nude_girl.png'))
similarity = detection_similarity(detection, [
((350, 159, 382, 173), 'eye', 0.7742469310760498),
((295, 169, 319, 181), 'eye', 0.7276312112808228)
])
assert similarity >= 0.9

def test_detect_eye_none(self):
assert detect_eyes(get_testfile('png_full.png')) == []
Loading

0 comments on commit 4967a9b

Please sign in to comment.