diff --git a/docs/source/api_doc/generic/classify.rst b/docs/source/api_doc/generic/classify.rst
index b3aeafd8114..d99ecd81087 100644
--- a/docs/source/api_doc/generic/classify.rst
+++ b/docs/source/api_doc/generic/classify.rst
@@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------
.. autoclass:: ClassifyModel
- :members: __init__, predict_score, predict, clear
+ :members: __init__, predict_score, predict, clear, make_ui, launch_demo
diff --git a/docs/source/api_doc/generic/yolo.rst b/docs/source/api_doc/generic/yolo.rst
index cd8882ab120..0097cfa7066 100644
--- a/docs/source/api_doc/generic/yolo.rst
+++ b/docs/source/api_doc/generic/yolo.rst
@@ -11,7 +11,7 @@ YOLOModel
-----------------------------------------
.. autoclass:: YOLOModel
- :members: __init__, predict, clear
+ :members: __init__, predict, clear, make_ui, launch_demo
diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py
index 57f73784341..3b3a04485d4 100644
--- a/imgutils/generic/classify.py
+++ b/imgutils/generic/classify.py
@@ -23,12 +23,19 @@
import numpy as np
from PIL import Image
+from hfutils.operate import get_hf_client
+from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model
+try:
+ import gradio as gr
+except (ImportError, ModuleNotFoundError):
+ gr = None
+
__all__ = [
'ClassifyModel',
'classify_predict_score',
@@ -36,6 +43,17 @@
]
+def _check_gradio_env():
+ """
+ Check if the Gradio library is installed and available.
+
+ :raises EnvironmentError: If Gradio is not installed.
+ """
+ if gr is None:
+ raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
+ f'Please install it with `pip install dghs-imgutils[demo]`.')
+
+
def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
@@ -287,6 +305,100 @@ def clear(self):
self._models.clear()
self._labels.clear()
+ def make_ui(self, default_model_name: Optional[str] = None):
+ """
+ Create the user interface components for the classifier model demo.
+
+ This method sets up the Gradio UI components including an image input, model selection dropdown,
+ submit button, and output label. It also configures the interaction between these components.
+
+ :param default_model_name: The name of the default model to be selected in the dropdown.
+ If None, the most recently updated model will be selected.
+ :type default_model_name: Optional[str]
+
+ :raises ImportError: If Gradio is not installed or properly configured.
+
+ :Example:
+ >>> model = ClassifyModel("username/repo_name")
+ >>> model.make_ui(default_model_name="model_v1")
+ """
+
+ # demo for classifier model
+ _check_gradio_env()
+ model_list = self.model_names
+ if not default_model_name:
+ hf_client = get_hf_client(hf_token=self._get_hf_token())
+ selected_model_name, selected_time = None, None
+ for fileitem in hf_client.get_paths_info(
+ repo_id=self.repo_id,
+ repo_type='model',
+ paths=[f'{model_name}/model.onnx' for model_name in model_list],
+ expand=True,
+ ):
+ if not selected_time or fileitem.last_commit.date > selected_time:
+ selected_model_name = os.path.dirname(fileitem.path)
+ selected_time = fileitem.last_commit.date
+ default_model_name = selected_model_name
+
+ with gr.Row():
+ with gr.Column():
+ gr_input_image = gr.Image(type='pil', label='Original Image')
+ gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
+ gr_submit = gr.Button(value='Submit', variant='primary')
+
+ with gr.Column():
+ gr_output = gr.Label(label='Prediction')
+
+ gr_submit.click(
+ self.predict_score,
+ inputs=[
+ gr_input_image,
+ gr_model,
+ ],
+ outputs=[gr_output],
+ )
+
+ def launch_demo(self, default_model_name: Optional[str] = None,
+ server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
+ """
+ Launch the Gradio demo for the classifier model.
+
+ This method creates a Gradio Blocks interface, sets up the UI components using make_ui(),
+ and launches the demo server.
+
+ :param default_model_name: The name of the default model to be selected in the dropdown.
+ :type default_model_name: Optional[str]
+ :param server_name: The name of the server to run the demo on. Defaults to None.
+ :type server_name: Optional[str]
+ :param server_port: The port number to run the demo on. Defaults to None.
+ :type server_port: Optional[int]
+ :param kwargs: Additional keyword arguments to pass to the Gradio launch method.
+
+ :raises ImportError: If Gradio is not installed or properly configured.
+
+ :Example:
+ >>> model = ClassifyModel("username/repo_name")
+ >>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860)
+ """
+
+ _check_gradio_env()
+ with gr.Blocks() as demo:
+ with gr.Row():
+ with gr.Column():
+ repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
+ gr.HTML(f'
Classifier Demo For {self.repo_id}
')
+ gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). '
+ f'Powered by `dghs-imgutils`\'s quick demo module.')
+
+ with gr.Row():
+ self.make_ui(default_model_name=default_model_name)
+
+ demo.launch(
+ server_name=server_name,
+ server_port=server_port,
+ **kwargs,
+ )
+
@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
diff --git a/imgutils/generic/yolo.py b/imgutils/generic/yolo.py
index 3d51f5eca6f..af0c5a34479 100644
--- a/imgutils/generic/yolo.py
+++ b/imgutils/generic/yolo.py
@@ -20,18 +20,61 @@
import numpy as np
from PIL import Image
+from hbutils.color import rnd_colors
+from hfutils.operate import get_hf_client
+from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model
+try:
+ import gradio as gr
+except (ImportError, ModuleNotFoundError):
+ gr = None
+
__all__ = [
'YOLOModel',
'yolo_predict',
]
+def _check_gradio_env():
+ """
+ Check if the Gradio library is installed and available.
+
+ :raises EnvironmentError: If Gradio is not installed.
+ """
+ if gr is None:
+ raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
+ f'Please install it with `pip install dghs-imgutils[demo]`.')
+
+
+def _v_fix(v):
+ """
+ Round and convert a float value to an integer.
+
+ :param v: The float value to be rounded and converted.
+ :type v: float
+ :return: The rounded integer value.
+ :rtype: int
+ """
+ return int(round(v))
+
+
+def _bbox_fix(bbox):
+ """
+ Fix the bounding box coordinates by rounding them to integers.
+
+ :param bbox: The bounding box coordinates.
+ :type bbox: tuple
+ :return: A tuple of fixed (rounded to integer) bounding box coordinates.
+ :rtype: tuple
+ """
+ return tuple(map(_v_fix, bbox))
+
+
def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.
@@ -403,9 +446,146 @@ def predict(self, image: ImageTyping, model_name: str,
def clear(self):
"""
Clear cached model and metadata.
+
+ This method removes all cached models and their associated metadata from memory.
+ It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
"""
self._models.clear()
+ def make_ui(self, default_model_name: Optional[str] = None,
+ default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
+ """
+ Create a Gradio-based user interface for object detection.
+
+ This method sets up an interactive UI that allows users to upload images,
+ select models, and adjust detection parameters. It uses the Gradio library
+ to create the interface.
+
+ :param default_model_name: The name of the default model to use.
+ If None, the most recently updated model is selected.
+ :type default_model_name: Optional[str]
+ :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
+ :type default_conf_threshold: float
+ :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
+ :type default_iou_threshold: float
+
+ :raises ImportError: If Gradio is not installed in the environment.
+
+ :Example:
+
+ >>> model = YOLOModel("username/repo_name")
+ >>> model.make_ui(default_model_name="yolov5s")
+ """
+ _check_gradio_env()
+ model_list = self.model_names
+ if not default_model_name:
+ hf_client = get_hf_client(hf_token=self._get_hf_token())
+ selected_model_name, selected_time = None, None
+ for fileitem in hf_client.get_paths_info(
+ repo_id=self.repo_id,
+ repo_type='model',
+ paths=[f'{model_name}/model.onnx' for model_name in model_list],
+ expand=True,
+ ):
+ if not selected_time or fileitem.last_commit.date > selected_time:
+ selected_model_name = os.path.dirname(fileitem.path)
+ selected_time = fileitem.last_commit.date
+ default_model_name = selected_model_name
+
+ def _gr_detect(image: ImageTyping, model_name: str,
+ iou_threshold: float = 0.7, score_threshold: float = 0.25) \
+ -> gr.AnnotatedImage:
+ _, _, labels = self._open_model(model_name=model_name)
+ _colors = list(map(str, rnd_colors(len(labels))))
+ _color_map = dict(zip(labels, _colors))
+ return gr.AnnotatedImage(
+ value=(image, [
+ (_bbox_fix(bbox), label)
+ for bbox, label, _ in self.predict(
+ image=image,
+ model_name=model_name,
+ iou_threshold=iou_threshold,
+ conf_threshold=score_threshold,
+ )
+ ]),
+ color_map=_color_map,
+ label='Labeled',
+ )
+
+ with gr.Row():
+ with gr.Column():
+ gr_input_image = gr.Image(type='pil', label='Original Image')
+ gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
+ with gr.Row():
+ gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
+ gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
+
+ gr_submit = gr.Button(value='Submit', variant='primary')
+
+ with gr.Column():
+ gr_output_image = gr.AnnotatedImage(label="Labeled")
+
+ gr_submit.click(
+ _gr_detect,
+ inputs=[
+ gr_input_image,
+ gr_model,
+ gr_iou_threshold,
+ gr_score_threshold,
+ ],
+ outputs=[gr_output_image],
+ )
+
+ def launch_demo(self, default_model_name: Optional[str] = None,
+ default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
+ server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
+ """
+ Launch a Gradio demo for object detection.
+
+ This method creates and launches a Gradio demo that allows users to interactively
+ perform object detection on uploaded images using the YOLO model.
+
+ :param default_model_name: The name of the default model to use.
+ If None, the most recently updated model is selected.
+ :type default_model_name: Optional[str]
+ :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
+ :type default_conf_threshold: float
+ :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
+ :type default_iou_threshold: float
+ :param server_name: The name of the server to run the demo on. Default is None.
+ :type server_name: Optional[str]
+ :param server_port: The port to run the demo on. Default is None.
+ :type server_port: Optional[int]
+ :param kwargs: Additional keyword arguments to pass to gr.Blocks.launch().
+
+ :raises EnvironmentError: If Gradio is not installed in the environment.
+
+ Example:
+ >>> model = YOLOModel("username/repo_name")
+ >>> model.launch_demo(default_model_name="yolov5s", server_name="0.0.0.0", server_port=7860)
+ """
+ _check_gradio_env()
+ with gr.Blocks() as demo:
+ with gr.Row():
+ with gr.Column():
+ repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
+ gr.HTML(f'YOLO Demo For {self.repo_id}
')
+ gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). '
+ f'Powered by `dghs-imgutils`\'s quick demo module.')
+
+ with gr.Row():
+ self.make_ui(
+ default_model_name=default_model_name,
+ default_conf_threshold=default_conf_threshold,
+ default_iou_threshold=default_iou_threshold,
+ )
+
+ demo.launch(
+ server_name=server_name,
+ server_port=server_port,
+ **kwargs,
+ )
+
@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:
diff --git a/requirements-demo.txt b/requirements-demo.txt
new file mode 100644
index 00000000000..f85fcdfb026
--- /dev/null
+++ b/requirements-demo.txt
@@ -0,0 +1 @@
+gradio>=4.44.0
\ No newline at end of file