Skip to content

Commit

Permalink
Declare exports in __all__ for type checking (#10238)
Browse files Browse the repository at this point in the history
* Declare exports

* add changeset

* type fixes

* more type fixes

* add changeset

* notebooks

* changes

---------

Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Freddy Boulton <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
  • Loading branch information
4 people authored Dec 23, 2024
1 parent f0cf3b7 commit 3f19210
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .changeset/young-geckos-brake.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Declare exports in __all__ for type checking
4 changes: 2 additions & 2 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,8 +1356,8 @@ def _upload_file(self, f: dict, data_index: int) -> dict[str, str]:
f"File {file_path} exceeds the maximum file size of {max_file_size} bytes "
f"set in {component_config.get('label', '') + ''} component."
)
with open(file_path, "rb") as f:
files = [("files", (orig_name.name, f))]
with open(file_path, "rb") as f_:
files = [("files", (orig_name.name, f_))]
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
Expand Down
2 changes: 1 addition & 1 deletion demo/agent_chatbot/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space(\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg))\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space( # type: ignore\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg)) # type: ignore\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
4 changes: 2 additions & 2 deletions demo/agent_chatbot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore

# Import tool from Hub
image_generation_tool = Tool.from_space(
image_generation_tool = Tool.from_space( # type: ignore
space_id="black-forest-labs/FLUX.1-schnell",
name="image_generator",
description="Generates an image following your prompt. Returns a PIL Image.",
Expand All @@ -20,7 +20,7 @@ def interact_with_agent(prompt, history):
messages = []
yield messages
for msg in stream_to_gradio(agent, prompt):
messages.append(asdict(msg))
messages.append(asdict(msg)) # type: ignore
yield messages
yield messages

Expand Down
118 changes: 118 additions & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,121 @@
from gradio.ipython_ext import load_ipython_extension

__version__ = get_package_version()

__all__ = [
"Accordion",
"AnnotatedImage",
"Annotatedimage",
"Audio",
"BarPlot",
"Blocks",
"BrowserState",
"Brush",
"Button",
"CSVLogger",
"ChatInterface",
"ChatMessage",
"Chatbot",
"Checkbox",
"CheckboxGroup",
"Checkboxgroup",
"ClearButton",
"Code",
"ColorPicker",
"Column",
"CopyData",
"DataFrame",
"Dataframe",
"Dataset",
"DateTime",
"DeletedFileData",
"DownloadButton",
"DownloadData",
"Dropdown",
"DuplicateButton",
"EditData",
"Eraser",
"Error",
"EventData",
"Examples",
"File",
"FileData",
"FileExplorer",
"FileSize",
"Files",
"FlaggingCallback",
"Gallery",
"Group",
"HTML",
"Highlight",
"HighlightedText",
"Highlightedtext",
"IS_WASM",
"Image",
"ImageEditor",
"ImageMask",
"Info",
"Interface",
"JSON",
"Json",
"KeyUpData",
"Label",
"LikeData",
"LinePlot",
"List",
"LoginButton",
"Markdown",
"Matrix",
"MessageDict",
"Mic",
"Microphone",
"Model3D",
"MultimodalTextbox",
"NO_RELOAD",
"Number",
"Numpy",
"OAuthProfile",
"OAuthToken",
"Paint",
"ParamViewer",
"PlayableVideo",
"Plot",
"Progress",
"Radio",
"Request",
"RetryData",
"Row",
"ScatterPlot",
"SelectData",
"SimpleCSVLogger",
"Sketchpad",
"Slider",
"State",
"Tab",
"TabItem",
"TabbedInterface",
"Tabs",
"Text",
"TextArea",
"Textbox",
"Theme",
"Timer",
"UndoData",
"UploadButton",
"Video",
"Warning",
"WaveformOptions",
"__version__",
"close_all",
"deploy",
"get_package_version",
"load",
"load_chat",
"load_ipython_extension",
"mount_gradio_app",
"on",
"render",
"set_static_paths",
"skip",
"update",
]
5 changes: 1 addition & 4 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,10 +1475,7 @@ def is_callable(self, fn_index: int = 0) -> bool:
return False
if any(block.stateful for block in dependency.inputs):
return False
if any(block.stateful for block in dependency.outputs):
return False

return True
return not any(block.stateful for block in dependency.outputs)

def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
"""
Expand Down
2 changes: 1 addition & 1 deletion gradio/cli/commands/components/_docs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_param_name(param):

def format_none(value):
"""Formats None and NonType values."""
if value is None or value is type(None) or value == "None" or value == "NoneType":
if value is None or value is type(None) or value in ("None", "NoneType"):
return "None"
return value

Expand Down
8 changes: 4 additions & 4 deletions gradio/components/native_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ def __init__(
every=every,
inputs=inputs,
)
for key, val in kwargs.items():
if key == "color_legend_title":
for key_, val in kwargs.items():
if key_ == "color_legend_title":
self.color_title = val
if key in [
if key_ in [
"stroke_dash",
"overlay_point",
"x_label_angle",
Expand All @@ -161,7 +161,7 @@ def __init__(
"width",
]:
warnings.warn(
f"Argument '{key}' has been deprecated.", DeprecationWarning
f"Argument '{key_}' has been deprecated.", DeprecationWarning
)

def get_block_name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion gradio/components/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def postprocess(
"""
if self.streaming:
return value # type: ignore
if value is None or value == [None, None] or value == (None, None):
if value is None or value in ([None, None], (None, None)):
return None
if isinstance(value, (str, Path)):
processed_files = (self._format_video(value), None)
Expand Down
7 changes: 4 additions & 3 deletions gradio/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def load_chat(
[{"role": "system", "content": system_message}] if system_message else []
)

def open_api(message: str, history: list | None) -> str:
def open_api(message: str, history: list | None) -> str | None:
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
Expand All @@ -641,7 +641,8 @@ def open_api_stream(
)
response = ""
for chunk in stream:
response += chunk.choices[0].delta.content
yield response
if chunk.choices[0].delta.content is not None:
response += chunk.choices[0].delta.content
yield response

return ChatInterface(open_api_stream if streaming else open_api, type="messages")
3 changes: 1 addition & 2 deletions gradio/themes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def to_dict(self):
if (
not prop.startswith("_")
or prop.startswith("_font")
or prop == "_stylesheets"
or prop == "name"
or prop in ("_stylesheets", "name")
) and isinstance(getattr(self, prop), (list, str)):
schema["theme"][prop] = getattr(self, prop)
return schema
Expand Down
2 changes: 1 addition & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,7 @@ def compare_objects(obj1, obj2, path=None):
if obj1 == obj2:
return edits

if type(obj1) != type(obj2):
if type(obj1) is not type(obj2):
edits.append(("replace", path, obj2))
return edits

Expand Down
2 changes: 1 addition & 1 deletion test/components/test_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,5 @@ def test_gallery_format(self):
output = gallery.postprocess(
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
)
if type(output.root[0]) == GalleryImage:
if isinstance(output.root[0], GalleryImage):
assert output.root[0].image.path.endswith(".jpeg")
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class GenericObject:
for x in test_objs:
hints = get_type_hints(x)
assert len(hints) == 1
assert hints["s"] == str
assert hints["s"] is str

assert len(get_type_hints(GenericObject())) == 0

Expand Down

0 comments on commit 3f19210

Please sign in to comment.