From 4572a74d5008ec0caab24bf4ec0a77a98b0a115c Mon Sep 17 00:00:00 2001 From: Dmitry Ustalov Date: Mon, 23 Dec 2024 23:33:22 +0100 Subject: [PATCH] Declare exports in __all__ for type checking (#10238) * Declare exports * add changeset * type fixes * more type fixes * add changeset * notebooks * changes --------- Co-authored-by: gradio-pr-bot Co-authored-by: Freddy Boulton Co-authored-by: Abubakar Abid --- .changeset/young-geckos-brake.md | 6 + client/python/gradio_client/client.py | 4 +- demo/agent_chatbot/run.ipynb | 2 +- demo/agent_chatbot/run.py | 4 +- gradio/__init__.py | 118 ++++++++++++++++++ gradio/blocks.py | 5 +- gradio/cli/commands/components/_docs_utils.py | 2 +- gradio/components/native_plot.py | 8 +- gradio/components/video.py | 2 +- gradio/external.py | 7 +- gradio/themes/base.py | 3 +- gradio/utils.py | 2 +- test/components/test_gallery.py | 2 +- test/test_utils.py | 2 +- 14 files changed, 144 insertions(+), 23 deletions(-) create mode 100644 .changeset/young-geckos-brake.md diff --git a/.changeset/young-geckos-brake.md b/.changeset/young-geckos-brake.md new file mode 100644 index 0000000000000..a350d60010dac --- /dev/null +++ b/.changeset/young-geckos-brake.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +fix:Declare exports in __all__ for type checking diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 36e17a1a8cef0..9f8ca4597b844 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -1356,8 +1356,8 @@ def _upload_file(self, f: dict, data_index: int) -> dict[str, str]: f"File {file_path} exceeds the maximum file size of {max_file_size} bytes " f"set in {component_config.get('label', '') + ''} component." ) - with open(file_path, "rb") as f: - files = [("files", (orig_name.name, f))] + with open(file_path, "rb") as f_: + files = [("files", (orig_name.name, f_))] r = httpx.post( self.client.upload_url, headers=self.client.headers, diff --git a/demo/agent_chatbot/run.ipynb b/demo/agent_chatbot/run.ipynb index ccf9c2fe27edc..aec9695588921 100644 --- a/demo/agent_chatbot/run.ipynb +++ b/demo/agent_chatbot/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space(\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg))\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space( # type: ignore\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg)) # type: ignore\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/agent_chatbot/run.py b/demo/agent_chatbot/run.py index 4e6cf36214536..7467971e88245 100644 --- a/demo/agent_chatbot/run.py +++ b/demo/agent_chatbot/run.py @@ -4,7 +4,7 @@ from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore # Import tool from Hub -image_generation_tool = Tool.from_space( +image_generation_tool = Tool.from_space( # type: ignore space_id="black-forest-labs/FLUX.1-schnell", name="image_generator", description="Generates an image following your prompt. Returns a PIL Image.", @@ -20,7 +20,7 @@ def interact_with_agent(prompt, history): messages = [] yield messages for msg in stream_to_gradio(agent, prompt): - messages.append(asdict(msg)) + messages.append(asdict(msg)) # type: ignore yield messages yield messages diff --git a/gradio/__init__.py b/gradio/__init__.py index 9375c4c512e4e..e5c0ae975313b 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -119,3 +119,121 @@ from gradio.ipython_ext import load_ipython_extension __version__ = get_package_version() + +__all__ = [ + "Accordion", + "AnnotatedImage", + "Annotatedimage", + "Audio", + "BarPlot", + "Blocks", + "BrowserState", + "Brush", + "Button", + "CSVLogger", + "ChatInterface", + "ChatMessage", + "Chatbot", + "Checkbox", + "CheckboxGroup", + "Checkboxgroup", + "ClearButton", + "Code", + "ColorPicker", + "Column", + "CopyData", + "DataFrame", + "Dataframe", + "Dataset", + "DateTime", + "DeletedFileData", + "DownloadButton", + "DownloadData", + "Dropdown", + "DuplicateButton", + "EditData", + "Eraser", + "Error", + "EventData", + "Examples", + "File", + "FileData", + "FileExplorer", + "FileSize", + "Files", + "FlaggingCallback", + "Gallery", + "Group", + "HTML", + "Highlight", + "HighlightedText", + "Highlightedtext", + "IS_WASM", + "Image", + "ImageEditor", + "ImageMask", + "Info", + "Interface", + "JSON", + "Json", + "KeyUpData", + "Label", + "LikeData", + "LinePlot", + "List", + "LoginButton", + "Markdown", + "Matrix", + "MessageDict", + "Mic", + "Microphone", + "Model3D", + "MultimodalTextbox", + "NO_RELOAD", + "Number", + "Numpy", + "OAuthProfile", + "OAuthToken", + "Paint", + "ParamViewer", + "PlayableVideo", + "Plot", + "Progress", + "Radio", + "Request", + "RetryData", + "Row", + "ScatterPlot", + "SelectData", + "SimpleCSVLogger", + "Sketchpad", + "Slider", + "State", + "Tab", + "TabItem", + "TabbedInterface", + "Tabs", + "Text", + "TextArea", + "Textbox", + "Theme", + "Timer", + "UndoData", + "UploadButton", + "Video", + "Warning", + "WaveformOptions", + "__version__", + "close_all", + "deploy", + "get_package_version", + "load", + "load_chat", + "load_ipython_extension", + "mount_gradio_app", + "on", + "render", + "set_static_paths", + "skip", + "update", +] diff --git a/gradio/blocks.py b/gradio/blocks.py index 8aabdbfb418e1..b4f227822d894 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1475,10 +1475,7 @@ def is_callable(self, fn_index: int = 0) -> bool: return False if any(block.stateful for block in dependency.inputs): return False - if any(block.stateful for block in dependency.outputs): - return False - - return True + return not any(block.stateful for block in dependency.outputs) def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None): """ diff --git a/gradio/cli/commands/components/_docs_utils.py b/gradio/cli/commands/components/_docs_utils.py index 9b61b76889149..9cf879e491df5 100644 --- a/gradio/cli/commands/components/_docs_utils.py +++ b/gradio/cli/commands/components/_docs_utils.py @@ -71,7 +71,7 @@ def get_param_name(param): def format_none(value): """Formats None and NonType values.""" - if value is None or value is type(None) or value == "None" or value == "NoneType": + if value is None or value is type(None) or value in ("None", "NoneType"): return "None" return value diff --git a/gradio/components/native_plot.py b/gradio/components/native_plot.py index 9016febdb17e8..dfe9a675d7a5e 100644 --- a/gradio/components/native_plot.py +++ b/gradio/components/native_plot.py @@ -147,10 +147,10 @@ def __init__( every=every, inputs=inputs, ) - for key, val in kwargs.items(): - if key == "color_legend_title": + for key_, val in kwargs.items(): + if key_ == "color_legend_title": self.color_title = val - if key in [ + if key_ in [ "stroke_dash", "overlay_point", "x_label_angle", @@ -161,7 +161,7 @@ def __init__( "width", ]: warnings.warn( - f"Argument '{key}' has been deprecated.", DeprecationWarning + f"Argument '{key_}' has been deprecated.", DeprecationWarning ) def get_block_name(self) -> str: diff --git a/gradio/components/video.py b/gradio/components/video.py index 83a0078c07da2..f84471e65c77c 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -276,7 +276,7 @@ def postprocess( """ if self.streaming: return value # type: ignore - if value is None or value == [None, None] or value == (None, None): + if value is None or value in ([None, None], (None, None)): return None if isinstance(value, (str, Path)): processed_files = (self._format_video(value), None) diff --git a/gradio/external.py b/gradio/external.py index d410dcdce304b..2e2d0ccf76860 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -615,7 +615,7 @@ def load_chat( [{"role": "system", "content": system_message}] if system_message else [] ) - def open_api(message: str, history: list | None) -> str: + def open_api(message: str, history: list | None) -> str | None: history = history or start_message if len(history) > 0 and isinstance(history[0], (list, tuple)): history = ChatInterface._tuples_to_messages(history) @@ -641,7 +641,8 @@ def open_api_stream( ) response = "" for chunk in stream: - response += chunk.choices[0].delta.content - yield response + if chunk.choices[0].delta.content is not None: + response += chunk.choices[0].delta.content + yield response return ChatInterface(open_api_stream if streaming else open_api, type="messages") diff --git a/gradio/themes/base.py b/gradio/themes/base.py index 51230be69623d..31d571c49d959 100644 --- a/gradio/themes/base.py +++ b/gradio/themes/base.py @@ -126,8 +126,7 @@ def to_dict(self): if ( not prop.startswith("_") or prop.startswith("_font") - or prop == "_stylesheets" - or prop == "name" + or prop in ("_stylesheets", "name") ) and isinstance(getattr(self, prop), (list, str)): schema["theme"][prop] = getattr(self, prop) return schema diff --git a/gradio/utils.py b/gradio/utils.py index 7ad4baa84da78..c6d5ba5a6aace 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1244,7 +1244,7 @@ def compare_objects(obj1, obj2, path=None): if obj1 == obj2: return edits - if type(obj1) != type(obj2): + if type(obj1) is not type(obj2): edits.append(("replace", path, obj2)) return edits diff --git a/test/components/test_gallery.py b/test/components/test_gallery.py index 9b9ca67022461..eac9ce763a596 100644 --- a/test/components/test_gallery.py +++ b/test/components/test_gallery.py @@ -126,5 +126,5 @@ def test_gallery_format(self): output = gallery.postprocess( [np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)] ) - if type(output.root[0]) == GalleryImage: + if isinstance(output.root[0], GalleryImage): assert output.root[0].image.path.endswith(".jpeg") diff --git a/test/test_utils.py b/test/test_utils.py index 7b28c101a2a8a..7f60554bc28c6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -304,7 +304,7 @@ class GenericObject: for x in test_objs: hints = get_type_hints(x) assert len(hints) == 1 - assert hints["s"] == str + assert hints["s"] is str assert len(get_type_hints(GenericObject())) == 0