Skip to content

Commit

Permalink
dev(narugo): init version for anime style ages
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Dec 16, 2023
1 parent c595725 commit 1ff6b12
Show file tree
Hide file tree
Showing 63 changed files with 301 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/validate/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ imgutils.validate
nsfw
portrait
rating
style_age
teen
truncate
18 changes: 18 additions & 0 deletions docs/source/api_doc/validate/style_age.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import glob
import os.path

from natsort import natsorted

from plot import image_plot

if __name__ == '__main__':
image_plot(
*natsorted(glob.glob(os.path.join('style_age', '1970s-', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '1980s', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '1990s', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '2000s', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '2010s', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '2015s', '*.jpg'))),
*natsorted(glob.glob(os.path.join('style_age', '2020s', '*.jpg'))),
columns=4, figsize=(10, 26),
)
21 changes: 21 additions & 0 deletions docs/source/api_doc/validate/style_age.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
imgutils.validate.style_age
=============================================

.. currentmodule:: imgutils.validate.style_age

.. automodule:: imgutils.validate.style_age


anime_style_age_score
-----------------------------

.. autofunction:: anime_style_age_score



anime_style_age
-----------------------------

.. autofunction:: anime_style_age


Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions docs/source/api_doc/validate/style_age_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import random

from huggingface_hub import HfFileSystem

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import anime_style_age

hf_fs = HfFileSystem()

_REPOSITORY = 'deepghs/anime_style_ages'
_MODEL_NAMES = [
os.path.relpath(file, _REPOSITORY).split('/')[0] for file in
hf_fs.glob(f'{_REPOSITORY}/*/model.onnx')
]


class AnimeStyleAgeBenchmark(BaseBenchmark):
def __init__(self, model):
BaseBenchmark.__init__(self)
self.model = model

def load(self):
from imgutils.validate.style_age import _open_anime_style_age_model
_ = _open_anime_style_age_model(self.model)

def unload(self):
from imgutils.validate.style_age import _open_anime_style_age_model
_open_anime_style_age_model.cache_clear()

def run(self):
image_file = random.choice(self.all_images)
_ = anime_style_age(image_file, self.model)


if __name__ == '__main__':
create_plot_cli(
[
(name, AnimeStyleAgeBenchmark(name))
for name in _MODEL_NAMES
],
title='Benchmark for Anime Portrait Models',
run_times=10,
try_times=20,
)()
1 change: 1 addition & 0 deletions imgutils/validate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
from .nsfw import *
from .portrait import *
from .rating import *
from .style_age import *
from .teen import *
from .truncate import *
178 changes: 178 additions & 0 deletions imgutils/validate/style_age.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""
Overview:
A model for classifying anime style_age images into 7 classes
(``1970s-``, ``1980s``, ``1990s``, ``2000s``, ``2010s``, ``2015s``, ``2020s``).
The following are sample images for testing.
.. image:: style_age.plot.py.svg
:align: center
This is an overall benchmark of all the style_age classification models:
.. image:: style_age_benchmark.plot.py.svg
:align: center
The models are hosted on
`huggingface - deepghs/anime_style_ages <https://huggingface.co/deepghs/anime_style_ages>`_.
"""
import json
from functools import lru_cache
from typing import Tuple, Optional, Dict, List

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model

__all__ = [
'anime_style_age_score',
'anime_style_age',
]

_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist'


@lru_cache()
def _open_anime_style_age_model(model_name):
"""
Open the anime style age model.
:param model_name: The model name.
:type model_name: str
:return: The ONNX model.
"""
return open_onnx_model(hf_hub_download(
f'deepghs/anime_style_ages',
f'{model_name}/model.onnx',
))


@lru_cache()
def _get_anime_style_age_labels(model_name) -> List[str]:
"""
Get the labels for the anime style age model.
:param model_name: The model name.
:type model_name: str
:return: The list of labels.
:rtype: List[str]
"""
with open(hf_hub_download(
f'deepghs/anime_style_ages',
f'{model_name}/meta.json',
), 'r') as f:
return json.load(f)['labels']


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Encode the input image.
:param image: The input image.
:type image: Image.Image
:param size: The desired size of the image.
:type size: Tuple[int, int]
:param normalize: Mean and standard deviation for normalization. Default is (0.5, 0.5).
:type normalize: Optional[Tuple[float, float]]
:return: The encoded image data.
:rtype: np.ndarray
"""
image = image.resize(size, Image.BILINEAR)
data = rgb_encode(image, order_='CHW')

if normalize is not None:
mean_, std_ = normalize
mean = np.asarray([mean_]).reshape((-1, 1, 1))
std = np.asarray([std_]).reshape((-1, 1, 1))
data = (data - mean) / std

return data.astype(np.float32)


def _raw_anime_style_age(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME):
"""
Perform raw anime style age processing on the input image.
:param image: The input image.
:type image: ImageTyping
:param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
:type model_name: str
:return: The processed image data.
:rtype: np.ndarray
"""
image = load_image(image, force_background='white', mode='RGB')
input_ = _img_encode(image)[None, ...]
output, = _open_anime_style_age_model(model_name).run(['output'], {'input': input_})
return output


def anime_style_age_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
"""
Get the scores for different types in an anime style age.
:param image: The input image.
:type image: ImageTyping
:param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
:type model_name: str
:return: A dictionary with type scores.
:rtype: Dict[str, float]
Examples::
>>> from imgutils.validate import anime_style_age_score
>>>
>>> anime_style_age_score('style_age/1970s-/1.jpg')
{'1970s-': 0.9805465340614319, '1980s': 8.761269782553427e-06, '1990s': 0.0005044879508204758, '2000s': 0.01569165475666523, '2010s': 0.002850610064342618, '2015s': 0.00037849770160391927, '2020s': 1.9434612113400362e-05}
>>> anime_style_age_score('style_age/1980s/5.jpg')
{'1970s-': 9.053497342392802e-05, '1980s': 0.9992554783821106, '1990s': 0.0006490182713605464, '2000s': 2.8857468805654207e-06, '2010s': 4.317252262353577e-07, '2015s': 6.314484721769986e-07, '2020s': 1.0750001138148946e-06}
>>> anime_style_age_score('style_age/1990s/9.jpg')
{'1970s-': 1.706833609205205e-05, '1980s': 0.00034479793976061046, '1990s': 0.9995512366294861, '2000s': 4.391363472677767e-05, '2010s': 1.4607510820496827e-05, '2015s': 2.0679690351244062e-05, '2020s': 7.661913514311891e-06}
>>> anime_style_age_score('style_age/2000s/13.jpg')
{'1970s-': 3.757471131393686e-05, '1980s': 3.0744897230761126e-05, '1990s': 2.76177470368566e-05, '2000s': 0.9996387958526611, '2010s': 9.160279296338558e-05, '2015s': 0.00013228354509919882, '2020s': 4.1361367038916796e-05}
>>> anime_style_age_score('style_age/2010s/17.jpg')
{'1970s-': 7.464057489414699e-06, '1980s': 3.2412899599876255e-05, '1990s': 5.703883653040975e-05, '2000s': 9.127358498517424e-05, '2010s': 0.9973921775817871, '2015s': 0.0022309015039354563, '2020s': 0.00018872201326303184}
>>> anime_style_age_score('style_age/2015s/21.jpg')
{'1970s-': 3.780902943617548e-06, '1980s': 1.422096920578042e-05, '1990s': 1.638929097680375e-05, '2000s': 2.152203023797483e-06, '2010s': 0.00028818511054851115, '2015s': 0.9996094107627869, '2020s': 6.58777353237383e-05}
>>> anime_style_age_score('style_age/2020s/25.jpg')
{'1970s-': 1.9200742826797068e-05, '1980s': 0.00017117452807724476, '1990s': 9.518441947875544e-05, '2000s': 2.885544381570071e-05, '2010s': 1.4389253010449465e-05, '2015s': 3.1696006772108376e-05, '2020s': 0.9996393918991089}
"""
output = _raw_anime_style_age(image, model_name)
values = dict(zip(_get_anime_style_age_labels(model_name), map(lambda x: x.item(), output[0])))
return values


def anime_style_age(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
"""
Get the primary anime style age type and its score.
:param image: The input image.
:type image: ImageTyping
:param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
:type model_name: str
:return: A tuple with the primary type and its score.
:rtype: Tuple[str, float]
Examples::
>>> from imgutils.validate import anime_style_age
>>>
>>> anime_style_age('style_age/1970s-/1.jpg')
('1970s-', 0.9805465340614319)
>>> anime_style_age('style_age/1980s/5.jpg')
('1980s', 0.9992554783821106)
>>> anime_style_age('style_age/1990s/9.jpg')
('1990s', 0.9995512366294861)
>>> anime_style_age('style_age/2000s/13.jpg')
('2000s', 0.9996387958526611)
>>> anime_style_age('style_age/2010s/17.jpg')
('2010s', 0.9973921775817871)
>>> anime_style_age('style_age/2015s/21.jpg')
('2015s', 0.9996094107627869)
>>> anime_style_age('style_age/2020s/25.jpg')
('2020s', 0.9996393918991089)
"""
output = _raw_anime_style_age(image, model_name)[0]
max_id = np.argmax(output)
return _get_anime_style_age_labels(model_name)[max_id], output[max_id].item()
37 changes: 37 additions & 0 deletions test/validate/test_style_age.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import glob
import os.path

import pytest

from imgutils.validate import anime_style_age
from imgutils.validate.style_age import _open_anime_style_age_model, anime_style_age_score
from test.testings import get_testfile

_ROOT_DIR = get_testfile('anime_style_age')
_EXAMPLE_FILES = [
(os.path.relpath(file, _ROOT_DIR), os.path.basename(os.path.dirname(file)))
for file in glob.glob(get_testfile('anime_style_age', '**', '*.jpg'), recursive=True)
]


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
try:
yield
finally:
_open_anime_style_age_model.cache_clear()


@pytest.mark.unittest
class TestValidatePortrait:
@pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
def test_anime_style_age(self, image, label):
image_file = get_testfile('anime_style_age', image)
tag, score = anime_style_age(image_file)
assert tag == label

@pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
def test_anime_style_age_score(self, image, label):
image_file = get_testfile('anime_style_age', image)
scores = anime_style_age_score(image_file)
assert scores[label] > 0.5

0 comments on commit 1ff6b12

Please sign in to comment.