Skip to content

Commit

Permalink
dev(narugo): add bangumi_char classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Dec 16, 2023
1 parent c595725 commit cb273e3
Show file tree
Hide file tree
Showing 39 changed files with 346 additions and 0 deletions.
15 changes: 15 additions & 0 deletions docs/source/api_doc/validate/bangumi_char.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import glob
import os.path

from natsort import natsorted

from plot import image_plot

if __name__ == '__main__':
    # Gather the sample images one category at a time so the plot keeps the
    # vision -> imagery -> halfbody -> face ordering, naturally sorted inside
    # each category.
    sample_files = []
    for category in ('vision', 'imagery', 'halfbody', 'face'):
        pattern = os.path.join('bangumi_char', category, '*.jpg')
        sample_files.extend(natsorted(glob.glob(pattern)))

    image_plot(*sample_files, columns=4, figsize=(10, 15))
21 changes: 21 additions & 0 deletions docs/source/api_doc/validate/bangumi_char.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
imgutils.validate.bangumi_char
=============================================

.. currentmodule:: imgutils.validate.bangumi_char

.. automodule:: imgutils.validate.bangumi_char


anime_bangumi_char_score
-----------------------------

.. autofunction:: anime_bangumi_char_score



anime_bangumi_char
-----------------------------

.. autofunction:: anime_bangumi_char


Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions docs/source/api_doc/validate/bangumi_char_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import random

from huggingface_hub import HfFileSystem

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import anime_bangumi_char

# Client used to list the files of the model repository.
hf_fs = HfFileSystem()

_REPOSITORY = 'deepghs/bangumi_char_type'
# One entry per ``<model_name>/model.onnx`` found in the repository. The model
# name is the name of the directory that directly contains ``model.onnx``;
# taking it with dirname/basename (instead of splitting a ``relpath`` result on
# '/') does not depend on the current working directory or on the OS path
# separator.
_MODEL_NAMES = [
    os.path.basename(os.path.dirname(file))
    for file in hf_fs.glob(f'{_REPOSITORY}/*/model.onnx')
]


class AnimeBangumiCharacterBenchmark(BaseBenchmark):
    """
    Benchmark case for one bangumi character type model.

    :param model: Name of the model to benchmark; it is passed unchanged to
        the model opener and to :func:`anime_bangumi_char`.
    """

    def __init__(self, model):
        BaseBenchmark.__init__(self)
        self.model = model

    def load(self):
        """Open the model once, so that it sits in the opener's ``lru_cache``."""
        from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model
        _open_anime_bangumi_char_model(self.model)

    def unload(self):
        """Drop every model cached by the opener."""
        from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model
        _open_anime_bangumi_char_model.cache_clear()

    def run(self):
        """Classify one randomly chosen image from ``self.all_images``."""
        # ``all_images`` is not defined in this class; presumably it is
        # provided by ``BaseBenchmark``.
        anime_bangumi_char(random.choice(self.all_images), self.model)


if __name__ == '__main__':
    # One (display name, benchmark instance) pair per model in the repository.
    benchmarks = [(name, AnimeBangumiCharacterBenchmark(name)) for name in _MODEL_NAMES]
    cli = create_plot_cli(
        benchmarks,
        title='Benchmark for Bangumi Character Type Models',
        run_times=10,
        try_times=20,
    )
    cli()
1 change: 1 addition & 0 deletions docs/source/api_doc/validate/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ imgutils.validate
:maxdepth: 3

aicheck
bangumi_char
classify
color
monochrome
Expand Down
1 change: 1 addition & 0 deletions imgutils/validate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Tools for image validation and classification, which can be used to filter datasets.
"""
from .aicheck import *
from .bangumi_char import *
from .classify import *
from .color import *
from .monochrome import *
Expand Down
226 changes: 226 additions & 0 deletions imgutils/validate/bangumi_char.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
Overview:
A model for classifying anime bangumi character images into 4 classes
(``vision``, ``imagery``, ``halfbody``, ``face``).
The following are sample images for testing.
.. image:: bangumi_char.plot.py.svg
:align: center
This is an overall benchmark of all the bangumi character classification models:
.. image:: bangumi_char_benchmark.plot.py.svg
:align: center
The models are hosted on
`huggingface - deepghs/bangumi_char_type <https://huggingface.co/deepghs/bangumi_char_type>`_.
.. note::
Please note that the classification of bangumi character types is not based on the proportion
of the head in the image but on the completeness of facial details.
The specific definitions of the four types can be found `here <https://huggingface.co/datasets/deepghs/bangumi_char_type>`_.
In anime videos, **characters in secondary positions often lack details due to simplified animation**,
leading to their classification under the ``vision`` category.
**The other three types include images with complete and clear facial features**.
If you are looking for a classification model that judges the proportion of the head in an image,
please use the :func:`imgutils.validate.anime_portrait` function.
"""
import json
from functools import lru_cache
from typing import Tuple, Optional, Dict, List

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model

# Public names of this module; these are what ``from .bangumi_char import *``
# exposes.
__all__ = [
    'anime_bangumi_char_score',
    'anime_bangumi_char',
]

# Model used whenever a caller does not pass ``model_name`` explicitly.
_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist'


@lru_cache()
def _open_anime_bangumi_char_model(model_name):
    """
    Open the anime bangumi character model.

    The file ``<model_name>/model.onnx`` is fetched from the
    ``deepghs/bangumi_char_type`` repository with ``hf_hub_download`` and then
    opened with ``open_onnx_model``. The result is memoized per ``model_name``
    by ``lru_cache``; use ``_open_anime_bangumi_char_model.cache_clear()`` to
    drop the cached models.

    :param model_name: The model name.
    :type model_name: str
    :return: The ONNX model.
    """
    model_file = hf_hub_download(
        'deepghs/bangumi_char_type',
        f'{model_name}/model.onnx',
    )
    return open_onnx_model(model_file)


@lru_cache()
def _get_anime_bangumi_char_labels(model_name) -> List[str]:
    """
    Get the labels for the anime bangumi character model.

    The labels are read from the ``labels`` entry of ``<model_name>/meta.json``
    in the ``deepghs/bangumi_char_type`` repository. The result is memoized per
    ``model_name`` by ``lru_cache``, so every caller receives the same list
    object and must not mutate it.

    :param model_name: The model name.
    :type model_name: str
    :return: The list of labels.
    :rtype: List[str]
    """
    meta_file = hf_hub_download(
        'deepghs/bangumi_char_type',
        f'{model_name}/meta.json',
    )
    # Read the JSON explicitly as UTF-8 rather than with the platform's
    # default encoding, which is not UTF-8 everywhere.
    with open(meta_file, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Turn a PIL image into a float32 array for the model.

    The image is resized to ``size`` with bilinear resampling, encoded with
    ``rgb_encode`` in ``CHW`` order and, unless ``normalize`` is ``None``,
    shifted by the mean and divided by the standard deviation.

    :param image: The input image.
    :type image: Image.Image
    :param size: Target size passed to ``Image.resize``.
    :type size: Tuple[int, int]
    :param normalize: ``(mean, std)`` used for normalization, or ``None`` to skip it.
        Default is (0.5, 0.5).
    :type normalize: Optional[Tuple[float, float]]
    :return: The encoded image data, as ``float32``.
    :rtype: np.ndarray
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')
    if normalize is None:
        return encoded.astype(np.float32)

    mean_, std_ = normalize
    # Reshape so the statistics broadcast over the two trailing axes.
    mean = np.asarray([mean_]).reshape((-1, 1, 1))
    std = np.asarray([std_]).reshape((-1, 1, 1))
    return ((encoded - mean) / std).astype(np.float32)


def _raw_anime_bangumi_char(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME):
    """
    Run the bangumi character model on one image and return its raw output.

    The image is loaded as RGB on a white background, encoded with
    ``_img_encode`` and fed to the model as a batch of one.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
    :type model_name: str
    :return: The ``output`` tensor of the model for the single-image batch.
    :rtype: np.ndarray
    """
    rgb_image = load_image(image, force_background='white', mode='RGB')
    batch = _img_encode(rgb_image)[None, ...]  # prepend the batch axis
    session = _open_anime_bangumi_char_model(model_name)
    (output,) = session.run(['output'], {'input': batch})
    return output


def anime_bangumi_char_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
    """
    Score an image against every bangumi character type.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
    :type model_name: str
    :return: A dictionary mapping each label of the model to its score.
    :rtype: Dict[str, float]

    Examples::
        >>> from imgutils.validate import anime_bangumi_char_score
        >>>
        >>> anime_bangumi_char_score('bangumi_char/vision/1.jpg')
        {'vision': 0.9998525381088257, 'imagery': 0.00012103465269319713, 'halfbody': 1.6464786313008517e-05, 'face': 9.906112609314732e-06}
        >>> anime_bangumi_char_score('bangumi_char/vision/2.jpg')
        {'vision': 0.9997243285179138, 'imagery': 0.0002490800397936255, 'halfbody': 1.7215803381986916e-05, 'face': 9.354368557978887e-06}
        >>> anime_bangumi_char_score('bangumi_char/vision/3.jpg')
        {'vision': 0.9998849630355835, 'imagery': 8.90006631379947e-05, 'halfbody': 1.3920385754317977e-05, 'face': 1.2084233276254963e-05}
        >>> anime_bangumi_char_score('bangumi_char/vision/4.jpg')
        {'vision': 0.9998877048492432, 'imagery': 8.732793503440917e-05, 'halfbody': 1.4264976925915107e-05, 'face': 1.0623419257171918e-05}
        >>> anime_bangumi_char_score('bangumi_char/imagery/5.jpg')
        {'vision': 0.07076334953308105, 'imagery': 0.9290977716445923, 'halfbody': 0.0001044218079186976, 'face': 3.4467317163944244e-05}
        >>> anime_bangumi_char_score('bangumi_char/imagery/6.jpg')
        {'vision': 2.2568268832401372e-05, 'imagery': 0.9999498128890991, 'halfbody': 2.1810528778587468e-05, 'face': 5.879474429093534e-06}
        >>> anime_bangumi_char_score('bangumi_char/imagery/7.jpg')
        {'vision': 3.260669109295122e-05, 'imagery': 0.9999510049819946, 'halfbody': 1.2321036592766177e-05, 'face': 4.025227553938748e-06}
        >>> anime_bangumi_char_score('bangumi_char/imagery/8.jpg')
        {'vision': 1.4251427273848094e-05, 'imagery': 0.999957799911499, 'halfbody': 2.4273678718600422e-05, 'face': 3.6884023302263813e-06}
        >>> anime_bangumi_char_score('bangumi_char/halfbody/9.jpg')
        {'vision': 3.880981603288092e-05, 'imagery': 0.0002326338435523212, 'halfbody': 0.9996368885040283, 'face': 9.164971561403945e-05}
        >>> anime_bangumi_char_score('bangumi_char/halfbody/10.jpg')
        {'vision': 0.00020793956355191767, 'imagery': 0.13438372313976288, 'halfbody': 0.8652494549751282, 'face': 0.000158855298650451}
        >>> anime_bangumi_char_score('bangumi_char/halfbody/11.jpg')
        {'vision': 0.000238816806813702, 'imagery': 0.3589179217815399, 'halfbody': 0.6406960487365723, 'face': 0.0001471740542910993}
        >>> anime_bangumi_char_score('bangumi_char/halfbody/12.jpg')
        {'vision': 0.002255884697660804, 'imagery': 0.08208147436380386, 'halfbody': 0.9152728915214539, 'face': 0.00038967153523117304}
        >>> anime_bangumi_char_score('bangumi_char/face/13.jpg')
        {'vision': 9.227699592884164e-06, 'imagery': 1.0835404282261152e-05, 'halfbody': 5.1437502406770363e-05, 'face': 0.9999284744262695}
        >>> anime_bangumi_char_score('bangumi_char/face/14.jpg')
        {'vision': 1.2125529792683665e-05, 'imagery': 1.0218892384727951e-05, 'halfbody': 0.00011914174683624879, 'face': 0.9998584985733032}
        >>> anime_bangumi_char_score('bangumi_char/face/15.jpg')
        {'vision': 1.2007669283775613e-05, 'imagery': 1.6357082131435163e-05, 'halfbody': 5.3068713896209374e-05, 'face': 0.9999185800552368}
        >>> anime_bangumi_char_score('bangumi_char/face/16.jpg')
        {'vision': 1.066640925273532e-05, 'imagery': 9.529400813335087e-06, 'halfbody': 4.089402500540018e-05, 'face': 0.9999388456344604}
    """
    output = _raw_anime_bangumi_char(image, model_name)
    labels = _get_anime_bangumi_char_labels(model_name)
    # ``output[0]`` is the row for the only image in the batch; ``.item()``
    # converts each numpy scalar into a plain Python number.
    return {label: score.item() for label, score in zip(labels, output[0])}


def anime_bangumi_char(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
    """
    Classify an image and return the best-scoring bangumi character type.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
    :type model_name: str
    :return: A tuple of the highest-scoring label and its score.
    :rtype: Tuple[str, float]

    Examples::
        >>> from imgutils.validate import anime_bangumi_char
        >>>
        >>> anime_bangumi_char('bangumi_char/vision/1.jpg')
        ('vision', 0.9998525381088257)
        >>> anime_bangumi_char('bangumi_char/vision/2.jpg')
        ('vision', 0.9997243285179138)
        >>> anime_bangumi_char('bangumi_char/vision/3.jpg')
        ('vision', 0.9998849630355835)
        >>> anime_bangumi_char('bangumi_char/vision/4.jpg')
        ('vision', 0.9998877048492432)
        >>> anime_bangumi_char('bangumi_char/imagery/5.jpg')
        ('imagery', 0.9290977716445923)
        >>> anime_bangumi_char('bangumi_char/imagery/6.jpg')
        ('imagery', 0.9999498128890991)
        >>> anime_bangumi_char('bangumi_char/imagery/7.jpg')
        ('imagery', 0.9999510049819946)
        >>> anime_bangumi_char('bangumi_char/imagery/8.jpg')
        ('imagery', 0.999957799911499)
        >>> anime_bangumi_char('bangumi_char/halfbody/9.jpg')
        ('halfbody', 0.9996368885040283)
        >>> anime_bangumi_char('bangumi_char/halfbody/10.jpg')
        ('halfbody', 0.8652494549751282)
        >>> anime_bangumi_char('bangumi_char/halfbody/11.jpg')
        ('halfbody', 0.6406959295272827)
        >>> anime_bangumi_char('bangumi_char/halfbody/12.jpg')
        ('halfbody', 0.9152728915214539)
        >>> anime_bangumi_char('bangumi_char/face/13.jpg')
        ('face', 0.9999284744262695)
        >>> anime_bangumi_char('bangumi_char/face/14.jpg')
        ('face', 0.9998584985733032)
        >>> anime_bangumi_char('bangumi_char/face/15.jpg')
        ('face', 0.9999185800552368)
        >>> anime_bangumi_char('bangumi_char/face/16.jpg')
        ('face', 0.9999388456344604)
    """
    # Row 0 holds the scores of the only image in the batch.
    scores = _raw_anime_bangumi_char(image, model_name)[0]
    best = np.argmax(scores)
    labels = _get_anime_bangumi_char_labels(model_name)
    return labels[best], scores[best].item()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 37 additions & 0 deletions test/validate/test_bangumi_char.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import glob
import os.path

import pytest

from imgutils.validate import anime_bangumi_char
from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model, anime_bangumi_char_score
from test.testings import get_testfile

# Root directory of the sample images used by these tests.
_ROOT_DIR = get_testfile('bangumi_char')
# (path relative to _ROOT_DIR, expected label) pairs; the expected label is the
# name of the directory that directly contains the image.
_EXAMPLE_FILES = [
    (os.path.relpath(path, _ROOT_DIR), os.path.split(os.path.dirname(path))[1])
    for path in glob.glob(get_testfile('bangumi_char', '**', '*.jpg'), recursive=True)
]


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
    """Clear the model opener's cache once the tests of this module are done."""
    try:
        yield
    finally:
        # Runs during fixture teardown, i.e. after the last test of the module.
        _open_anime_bangumi_char_model.cache_clear()


@pytest.mark.unittest
class TestValidateBangumiChar:
    """Check the classifier against sample images grouped in per-label directories."""

    @pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
    def test_anime_bangumi_char(self, image, label):
        # The predicted type must be the directory the sample lives in.
        predicted, _ = anime_bangumi_char(get_testfile('bangumi_char', image))
        assert predicted == label

    @pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
    def test_anime_bangumi_char_score(self, image, label):
        # The expected type must hold more than half of the score.
        scores = anime_bangumi_char_score(get_testfile('bangumi_char', image))
        assert scores[label] > 0.5

0 comments on commit cb273e3

Please sign in to comment.