From 50cd29173b7e81b8c50cff512d64717cf8912a68 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 18 Sep 2024 19:37:46 +0800 Subject: [PATCH] dev(narugo): add gradio demo for classifiers --- docs/source/api_doc/generic/classify.rst | 2 +- imgutils/generic/classify.py | 73 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/docs/source/api_doc/generic/classify.rst b/docs/source/api_doc/generic/classify.rst index b3aeafd8114..d99ecd81087 100644 --- a/docs/source/api_doc/generic/classify.rst +++ b/docs/source/api_doc/generic/classify.rst @@ -11,7 +11,7 @@ ClassifyModel ----------------------------------------- .. autoclass:: ClassifyModel - :members: __init__, predict_score, predict, clear + :members: __init__, predict_score, predict, clear, make_ui, launch_demo diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py index 57f73784341..432c2cc2253 100644 --- a/imgutils/generic/classify.py +++ b/imgutils/generic/classify.py @@ -23,12 +23,19 @@ import numpy as np from PIL import Image +from hfutils.operate import get_hf_client +from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from ..data import rgb_encode, ImageTyping, load_image from ..utils import open_onnx_model +try: + import gradio as gr +except (ImportError, ModuleNotFoundError): + gr = None + __all__ = [ 'ClassifyModel', 'classify_predict_score', @@ -36,6 +43,17 @@ ] +def _check_gradio_env(): + """ + Check if the Gradio library is installed and available. + + :raises EnvironmentError: If Gradio is not installed. + """ + if gr is None: + raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' + f'Please install it with `pip install dghs-imgutils[demo]`.') + + def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): """ @@ -287,6 +305,61 @@ def clear(self): self._models.clear() self._labels.clear() + def make_ui(self, default_model_name: Optional[str] = None): + _check_gradio_env() + model_list = self.model_names + if not default_model_name: + hf_client = get_hf_client(hf_token=self._get_hf_token()) + selected_model_name, selected_time = None, None + for fileitem in hf_client.get_paths_info( + repo_id=self.repo_id, + repo_type='model', + paths=[f'{model_name}/model.onnx' for model_name in model_list], + expand=True, + ): + if not selected_time or fileitem.last_commit.date > selected_time: + selected_model_name = os.path.dirname(fileitem.path) + selected_time = fileitem.last_commit.date + default_model_name = selected_model_name + + with gr.Row(): + with gr.Column(): + gr_input_image = gr.Image(type='pil', label='Original Image') + gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') + gr_submit = gr.Button(value='Submit', variant='primary') + + with gr.Column(): + gr_output = gr.Label(label='Prediction') + + gr_submit.click( + self.predict_score, + inputs=[ + gr_input_image, + gr_model, + ], + outputs=[gr_output], + ) + + def launch_demo(self, default_model_name: Optional[str] = None, + server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): + _check_gradio_env() + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model') + gr.HTML(f'

Classifier Demo For {self.repo_id}

') + gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). ' + f'Powered by `dghs-imgutils`\'s quick demo module.') + + with gr.Row(): + self.make_ui(default_model_name=default_model_name) + + demo.launch( + server_name=server_name, + server_port=server_port, + **kwargs, + ) + @lru_cache() def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel: