diff --git a/extra-requirements.txt b/extra-requirements.txt index a64611dbe4d49..dc8ebfb076661 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -36,6 +36,7 @@ docarray>=0.16.4: core jina-hubble-sdk>=0.30.4: core jcloud>=0.0.35: core opentelemetry-api>=1.12.0: core +gradio: core opentelemetry-instrumentation-grpc>=0.35b0: core uvloop: perf,standard,devel prometheus_client>=0.12.0: perf,standard,devel diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index fb3470425fcc4..3281650886803 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -1,5 +1,7 @@ import inspect from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +import gradio as gr +from pydantic import BaseModel from jina import DocumentArray, Document from jina._docarray import docarray_v2 @@ -189,6 +191,15 @@ async def streaming_get(request: Request): output_doc_list_model=output_doc_model, ) + def _executor_caller(*args): + return args + + print(input_doc_model, output_doc_model) + inputs = generate_gradio_interface(input_doc_model) + outputs = generate_gradio_interface(output_doc_model) + interface = gr.Interface(_executor_caller, inputs, outputs) + app = gr.mount_gradio_app(app, interface, path="/_jina_gradio_demo") + from jina.serve.runtimes.gateway.health_model import JinaHealthModel @app.get( @@ -222,3 +233,62 @@ def _doc_to_event(doc): return {'event': 'update', 'data': doc.to_dict()} else: return {'event': 'update', 'data': doc.dict()} + +def generate_gradio_interface(model: BaseModel): + """helper function to convert pydantic model to gradio interface""" + inputs = [] + + # Process each attribute in the model + for attr, field in model.__annotations__.items(): + input_type = field.__name__ + input_label = attr.replace("_", " ").capitalize() + + # Generate appropriate input component based on the field type + if input_type == "str": + # Additional options for string type + field_info = model.__annotations__[attr] + default = field_info.default if hasattr(field_info, "default") else None + choices = field_info.choices if hasattr(field_info, "choices") else None + + input_component = gr.Textbox( + label=input_label, + ) + elif input_type == "int": + # Additional options for integer type + field_info = model.__annotations__[attr] + ge = field_info.ge if hasattr(field_info, "ge") else None + le = field_info.le if hasattr(field_info, "le") else None + + input_component = gr.Number( + label=input_label, + minimum=ge, + maximum=le, + step=1, + ) + elif input_type == "float": + # Additional options for float type + field_info = model.__annotations__[attr] + ge = field_info.ge if hasattr(field_info, "ge") else None + le = field_info.le if hasattr(field_info, "le") else None + + input_component = gr.Number( + label=input_label, + minimum=ge, + maximum=le, + step=0.01, + ) + elif input_type == "bool": + # Additional options for boolean type + field_info = model.__annotations__[attr] + input_component = gr.Checkbox(label=input_label) + elif input_type == "File": + input_component = gr.File(label=input_label) + elif input_type == "Path": + input_component = gr.Textbox(label=input_label) + else: + # For unsupported types, skip the attribute + continue + + # Add the input component to the inputs list + inputs.append(input_component) + return inputs