Skip to content

Commit

Permalink
dev(narugo): add docs for nudenet
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Sep 11, 2024
1 parent e668027 commit cd95a43
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/detect/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ imgutils.detect
halfbody
hand
head
nudenet
person
text
visual
Expand Down
14 changes: 14 additions & 0 deletions docs/source/api_doc/detect/nudenet.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
imgutils.detect.nudenet
==========================

.. currentmodule:: imgutils.detect.nudenet

.. automodule:: imgutils.detect.nudenet


detect_with_nudenet
------------------------------

.. autofunction:: detect_with_nudenet


34 changes: 34 additions & 0 deletions docs/source/api_doc/detect/nudenet_detect_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.detect import detect_with_nudenet


class NudenetDetectBenchmark(BaseBenchmark):
def __init__(self):
BaseBenchmark.__init__(self)

def load(self):
from imgutils.detect.nudenet import _open_nudenet_yolo, _open_nudenet_nms
_ = _open_nudenet_yolo()
_ = _open_nudenet_nms()

def unload(self):
from imgutils.detect.nudenet import _open_nudenet_yolo, _open_nudenet_nms
_open_nudenet_yolo.cache_clear()
_open_nudenet_nms.cache_clear()

def run(self):
image_file = random.choice(self.all_images)
_ = detect_with_nudenet(image_file)


if __name__ == '__main__':
create_plot_cli(
[
('Nudenet', NudenetDetectBenchmark()),
],
title='Benchmark for Anime NudeNet Detections',
run_times=10,
try_times=20,
)()
19 changes: 19 additions & 0 deletions docs/source/api_doc/detect/nudenet_detect_demo.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from imgutils.detect.nudenet import _LABELS, detect_with_nudenet
from imgutils.detect.visual import detection_visualize
from plot import image_plot


def _detect(img, **kwargs):
return detection_visualize(img, detect_with_nudenet(img, **kwargs), _LABELS)


if __name__ == '__main__':
image_plot(
(_detect('nudenet/nude_girl.png'), 'simple nude'),
(_detect('nudenet/simple_sex.jpg'), 'simple sex'),
(_detect('nudenet/complex_pose.jpg'), 'complex pose'),
(_detect('nudenet/complex_sex.jpg'), 'complex sex'),
columns=2,
figsize=(9, 9),
autonudenet=False,
)
82 changes: 78 additions & 4 deletions imgutils/detect/nudenet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
# NudeNet Model, from https://github.com/notAI-tech/NudeNet
# The ONNX models are hosted on https://huggingface.co/deepghs/nudenet_onnx
"""
Overview:
This module provides functionality for detecting nudity in images using the NudeNet model.
The module includes functions for preprocessing images, running the NudeNet YOLO model,
applying non-maximum suppression (NMS), and postprocessing the results. It utilizes
ONNX models hosted on `deepghs/nudenet_onnx <https://huggingface.co/deepghs/nudenet_onnx>`_
for efficient inference. The original project is
`notAI-tech/NudeNet <https://github.com/notAI-tech/NudeNet>`_.
.. collapse:: Overview of NudeNet Detect (NSFW Warning!!!)
.. image:: nudenet_detect_demo.plot.py.svg
:align: center
The main function :func:`detect_with_nudenet` can be used to perform nudity detection on
given images, returning a list of bounding boxes, labels, and confidence scores.
This is an overall benchmark of all the nudenet models:
.. image:: nudenet_detect_benchmark.plot.py.svg
:align: center
.. note::
This module requires onnxruntime version 1.18 or higher.
"""

from functools import lru_cache
from typing import Tuple, List

Expand All @@ -14,6 +40,11 @@


def _check_compatibility() -> bool:
"""
Check if the installed onnxruntime version is compatible with NudeNet.
:raises EnvironmentError: If the onnxruntime version is less than 1.18.
"""
import onnxruntime
if VersionInfo(onnxruntime.__version__) < '1.18':
raise EnvironmentError(f'Nudenet not supported on onnxruntime {onnxruntime.__version__}, '
Expand All @@ -27,6 +58,11 @@ def _check_compatibility() -> bool:

@lru_cache()
def _open_nudenet_yolo():
"""
Open and cache the NudeNet YOLO ONNX model.
:return: The loaded ONNX model for YOLO.
"""
return open_onnx_model(hf_hub_download(
repo_id=_REPO_ID,
repo_type='model',
Expand All @@ -36,15 +72,26 @@ def _open_nudenet_yolo():

@lru_cache()
def _open_nudenet_nms():
"""
Open and cache the NudeNet NMS ONNX model.
:return: The loaded ONNX model for NMS.
"""
return open_onnx_model(hf_hub_download(
repo_id=_REPO_ID,
repo_type='model',
filename='nms-yolov8.onnx',
))


def _nn_preprocessing(image: ImageTyping, model_size: int = 320) \
-> Tuple[np.ndarray, float]:
def _nn_preprocessing(image: ImageTyping, model_size: int = 320) -> Tuple[np.ndarray, float]:
"""
Preprocess the input image for the NudeNet model.
:param image: The input image.
:param model_size: The size to which the image should be resized (default: 320).
:return: A tuple containing the preprocessed image array and the scaling ratio.
"""
image = load_image(image, mode='RGB', force_background='white')
assert image.mode == 'RGB'
mat = np.array(image)
Expand All @@ -61,10 +108,25 @@ def _nn_preprocessing(image: ImageTyping, model_size: int = 320) \


def _make_np_config(topk: int = 100, iou_threshold: float = 0.45, score_threshold: float = 0.25) -> np.ndarray:
"""
Create a configuration array for the NMS model.
:param topk: The maximum number of detections to keep (default: 100).
:param iou_threshold: The IoU threshold for NMS (default: 0.45).
:param score_threshold: The score threshold for detections (default: 0.25).
:return: A numpy array containing the configuration parameters.
"""
return np.array([topk, iou_threshold, score_threshold]).astype(np.float32)


def _nn_postprocess(selected, global_ratio: float):
"""
Postprocess the model output to generate bounding boxes and labels.
:param selected: The output from the NMS model.
:param global_ratio: The scaling ratio to apply to the bounding boxes.
:return: A list of tuples, each containing a bounding box, label, and confidence score.
"""
bboxes = []
num_boxes = selected.shape[0]
for idx in range(num_boxes):
Expand Down Expand Up @@ -110,6 +172,18 @@ def _nn_postprocess(selected, global_ratio: float):
def detect_with_nudenet(image: ImageTyping, topk: int = 100,
iou_threshold: float = 0.45, score_threshold: float = 0.25) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Detect nudity in the given image using the NudeNet model.
:param image: The input image to analyze.
:param topk: The maximum number of detections to keep (default: 100).
:param iou_threshold: The IoU threshold for NMS (default: 0.45).
:param score_threshold: The score threshold for detections (default: 0.25).
:return: A list of tuples, each containing:
- A bounding box as (x1, y1, x2, y2)
- A label string
- A confidence score
"""
_check_compatibility()
input_, global_ratio = _nn_preprocessing(image, model_size=320)
config = _make_np_config(topk, iou_threshold, score_threshold)
Expand Down

0 comments on commit cd95a43

Please sign in to comment.