From 9c994e67edb3054c63a98e49ec72a158080e5517 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 5 Nov 2024 21:24:03 +0800 Subject: [PATCH] dev(narugo): support furry model --- imgutils/validate/__init__.py | 1 + imgutils/validate/furry.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 imgutils/validate/furry.py diff --git a/imgutils/validate/__init__.py b/imgutils/validate/__init__.py index 83879e7ae66..95b16135120 100644 --- a/imgutils/validate/__init__.py +++ b/imgutils/validate/__init__.py @@ -8,6 +8,7 @@ from .color import * from .completeness import * from .dbrating import * +from .furry import * from .monochrome import * from .nsfw import * from .portrait import * diff --git a/imgutils/validate/furry.py b/imgutils/validate/furry.py new file mode 100644 index 00000000000..0f861dea818 --- /dev/null +++ b/imgutils/validate/furry.py @@ -0,0 +1,57 @@ +""" +Overview: + A model for classifying anime furry images into 2 classes (``non_furry``, ``furry``). + + The following are sample images for testing. + + .. image:: furry.plot.py.svg + :align: center + + This is an overall benchmark of all the furry classification models: + + .. image:: furry_benchmark.plot.py.svg + :align: center + + The models are hosted on + `huggingface - deepghs/anime_furry `_. +""" +from typing import Tuple, Dict + +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score + +__all__ = [ + 'anime_furry_score', + 'anime_furry', +] + +_DEFAULT_MODEL_NAME = 'mobilenetv3_v0.1_dist' +_REPO_ID = 'deepghs/anime_furry' + + +def anime_furry_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: + """ + Get the scores for different types in a furry anime image. + + :param image: The input image. + :type image: ImageTyping + :param model_name: The model name. Default is 'mobilenetv3_v0.1_dist'. + :type model_name: str + :return: A dictionary with type scores. + :rtype: Dict[str, float] + """ + return classify_predict_score(image, _REPO_ID, model_name) + + +def anime_furry(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: + """ + Get the primary furry type and its score. + + :param image: The input image. + :type image: ImageTyping + :param model_name: The model name. Default is 'mobilenetv3_v0.1_dist'. + :type model_name: str + :return: A tuple with the primary type and its score. + :rtype: Tuple[str, float] + """ + return classify_predict(image, _REPO_ID, model_name)