Skip to content

Commit

Permalink
dev(narugo): add pydoc for methods
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Mar 24, 2024
1 parent b742ff9 commit 045a4d5
Showing 1 changed file with 92 additions and 2 deletions.
94 changes: 92 additions & 2 deletions imgutils/metrics/dbaesthetic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Overview:
A tool for assessing the aesthetic quality of anime images using a pre-trained model.
"""
from typing import Dict, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -25,6 +30,16 @@


def _value_replace(v, mapping):
"""
Replaces values in a data structure using a mapping dictionary.
:param v: The input data structure.
:type v: Any
:param mapping: A dictionary mapping values to replacement values.
:type mapping: Dict
:return: The modified data structure.
:rtype: Any
"""
if isinstance(v, (list, tuple)):
return type(v)([_value_replace(vitem, mapping) for vitem in v])
elif isinstance(v, dict):
Expand All @@ -39,16 +54,44 @@ def _value_replace(v, mapping):


class AestheticModel:
"""
A model for assessing the aesthetic quality of anime images.
"""

def __init__(self, repo_id: str):
"""
Initializes an AestheticModel instance.
:param repo_id: The repository ID of the aesthetic assessment model.
:type repo_id: str
"""
self.repo_id = repo_id
self.classifier = ClassifyModel(repo_id)
self.cached_samples: Dict[str, Tuple] = {}

def get_aesthetic_score(self, image: ImageTyping, model_name: str) -> Tuple[float, Dict[str, float]]:
"""
Calculates the aesthetic score and confidence for an anime image.
:param image: The input anime image.
:type image: ImageTyping
:param model_name: The name of the aesthetic assessment model to use.
:type model_name: str
:return: A tuple containing the aesthetic score and confidence.
:rtype: Tuple[float, Dict[str, float]]
"""
scores = self.classifier.predict_score(image, model_name)
return sum(scores[label] * i for i, label in enumerate(_LABELS)), scores

def _get_xy_samples(self, model_name: str):
"""
Retrieves cached samples for aesthetic assessment.
:param model_name: The name of the aesthetic assessment model.
:type model_name: str
:return: Cached samples for aesthetic assessment.
:rtype: Tuple[Tuple[np.ndarray, float, float], Tuple[np.ndarray, float, float]]
"""
if model_name not in self.cached_samples:
stacked = np.load(hf_hub_download(
repo_id=self.repo_id,
Expand All @@ -59,7 +102,17 @@ def _get_xy_samples(self, model_name: str):
self.cached_samples[model_name] = ((x, x.min(), x.max()), (y, y.min(), y.max()))
return self.cached_samples[model_name]

def score_to_percentile(self, score: float, model_name: str):
def score_to_percentile(self, score: float, model_name: str) -> float:
"""
Converts an aesthetic score to a percentile rank.
:param score: The aesthetic score.
:type score: float
:param model_name: The name of the aesthetic assessment model to use.
:type model_name: str
:return: The percentile rank corresponding to the given score.
:rtype: float
"""
(x, x_min, x_max), (y, y_min, y_max) = self._get_xy_samples(model_name)
idx = np.searchsorted(x, np.clip(score, a_min=x_min, a_max=x_max))
if idx < x.shape[0] - 1:
Expand All @@ -73,7 +126,17 @@ def score_to_percentile(self, score: float, model_name: str):
return y[idx]

@classmethod
def percentile_to_label(cls, percentile: float, mapping: Optional[Dict[str, float]] = None):
def percentile_to_label(cls, percentile: float, mapping: Optional[Dict[str, float]] = None) -> str:
"""
Converts a percentile rank to an aesthetic label.
:param percentile: The percentile rank.
:type percentile: float
:param mapping: A dictionary mapping labels to percentile thresholds.
:type mapping: Optional[Dict[str, float]]
:return: The aesthetic label corresponding to the given percentile rank.
:rtype: str
"""
mapping = mapping or _DEFAULT_LABEL_MAPPING
for label, threshold in sorted(mapping.items(), key=lambda x: (-x[1], x[0])):
if percentile >= threshold:
Expand All @@ -82,6 +145,18 @@ def percentile_to_label(cls, percentile: float, mapping: Optional[Dict[str, floa
raise ValueError(f'No label for unknown percentile {percentile:.3f}.')

def get_aesthetic(self, image: ImageTyping, model_name: str, fmt=('label', 'percentile')):
"""
Analyzes the aesthetic quality of an anime image and returns the results in the specified format.
:param image: The input anime image.
:type image: ImageTyping
:param model_name: The name of the aesthetic assessment model to use.
:type model_name: str
:param fmt: The format of the output.
:type fmt: Tuple[str, ...]
:return: A dictionary containing the aesthetic assessment results.
:rtype: Dict[str, float]
"""
score, confidence = self.get_aesthetic_score(image, model_name)
percentile = self.score_to_percentile(score, model_name)
label = self.percentile_to_label(percentile)
Expand All @@ -96,6 +171,9 @@ def get_aesthetic(self, image: ImageTyping, model_name: str, fmt=('label', 'perc
)

def clear(self):
"""
Clears the internal state of the AestheticModel instance.
"""
self.classifier.clear()
self.cached_samples.clear()

Expand All @@ -105,4 +183,16 @@ def clear(self):

def anime_dbaesthetic(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME,
fmt=('label', 'percentile')):
"""
Analyzes the aesthetic quality of an anime image using a pre-trained model.
:param image: The input anime image.
:type image: ImageTyping
:param model_name: The name of the aesthetic assessment model to use. Default is _DEFAULT_MODEL_NAME.
:type model_name: str
:param fmt: The format of the output. Default is ('label', 'percentile').
:type fmt: Tuple[str, ...]
:return: A dictionary containing the aesthetic assessment results.
:rtype: Dict[str, float]
"""
return _MODEL.get_aesthetic(image, model_name, fmt)

0 comments on commit 045a4d5

Please sign in to comment.