Skip to content

Commit

Permalink
dev(narugo): add gradio demo for classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Sep 18, 2024
1 parent 31a1d23 commit 50cd291
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/api_doc/generic/classify.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
:members: __init__, predict_score, predict, clear
:members: __init__, predict_score, predict, clear, make_ui, launch_demo



Expand Down
73 changes: 73 additions & 0 deletions imgutils/generic/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,37 @@

import numpy as np
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model

try:
import gradio as gr
except (ImportError, ModuleNotFoundError):
gr = None

__all__ = [
'ClassifyModel',
'classify_predict_score',
'classify_predict',
]


def _check_gradio_env():
"""
Check if the Gradio library is installed and available.
:raises EnvironmentError: If Gradio is not installed.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
f'Please install it with `pip install dghs-imgutils[demo]`.')


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Expand Down Expand Up @@ -287,6 +305,61 @@ def clear(self):
self._models.clear()
self._labels.clear()

def make_ui(self, default_model_name: Optional[str] = None):
_check_gradio_env()
model_list = self.model_names
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
gr_submit = gr.Button(value='Submit', variant='primary')

with gr.Column():
gr_output = gr.Label(label='Prediction')

gr_submit.click(
self.predict_score,
inputs=[
gr_input_image,
gr_model,
],
outputs=[gr_output],
)

def launch_demo(self, default_model_name: Optional[str] = None,
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
_check_gradio_env()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). '
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(default_model_name=default_model_name)

demo.launch(
server_name=server_name,
server_port=server_port,
**kwargs,
)


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
Expand Down

0 comments on commit 50cd291

Please sign in to comment.