Skip to content

Commit

Permalink
Refactor route decorator to support multiple paths and integrate exce…
Browse files Browse the repository at this point in the history
…ption handlers with renderer
  • Loading branch information
devxpy committed Jul 27, 2024
1 parent db521a6 commit 4652096
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 120 deletions.
75 changes: 41 additions & 34 deletions gooey_ui/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,47 @@ def __init__(self, query_params: dict, status_code=303):
super().__init__(url, status_code)


def runner(
def route(app, *paths, **kwargs):
def decorator(fn):
@wraps(fn)
def wrapper(request: Request, json_data: dict | None, **kwargs):
if "request" in fn_sig.parameters:
kwargs["request"] = request
if "json_data" in fn_sig.parameters:
kwargs["json_data"] = json_data
return renderer(
partial(fn, **kwargs),
query_params=dict(request.query_params),
state=json_data and json_data.get("state"),
)

fn_sig = inspect.signature(fn)
mod_params = dict(fn_sig.parameters) | dict(
request=inspect.Parameter(
"request",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Request,
),
json_data=inspect.Parameter(
"json_data",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Depends(_request_json),
annotation=typing.Optional[dict],
),
)
mod_sig = fn_sig.replace(parameters=list(mod_params.values()))
wrapper.__signature__ = mod_sig

for path in reversed(paths):
wrapper = app.get(path)(wrapper)
wrapper = app.post(path)(wrapper)

return wrapper

return decorator


def renderer(
fn: typing.Callable,
state: dict[str, typing.Any] = None,
query_params: dict[str, str] = None,
Expand Down Expand Up @@ -164,39 +204,6 @@ def runner(
continue


def route(fn):
@wraps(fn)
def wrapper(request: Request, json_data: dict | None, **kwargs):
if "request" in fn_sig.parameters:
kwargs["request"] = request
if "json_data" in fn_sig.parameters:
kwargs["json_data"] = json_data
return runner(
partial(fn, **kwargs),
query_params=dict(request.query_params),
state=json_data and json_data.get("state"),
)

fn_sig = inspect.signature(fn)
mod_params = dict(fn_sig.parameters) | dict(
request=inspect.Parameter(
"request",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=Request,
),
json_data=inspect.Parameter(
"json_data",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Depends(_request_json),
annotation=typing.Optional[dict],
),
)
mod_sig = fn_sig.replace(parameters=list(mod_params.values()))
wrapper.__signature__ = mod_sig

return wrapper


async def _request_json(request: Request) -> dict | None:
if request.headers.get("content-type") == "application/json":
return await request.json()
Expand Down
15 changes: 5 additions & 10 deletions routers/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
app = APIRouter()


@app.post("/payment-processing/")
@st.route
@st.route(app, "/payment-processing/")
def payment_processing_route(
request: Request, provider: str | None = None, subscription_id: str | None = None
):
Expand Down Expand Up @@ -76,8 +75,7 @@ def payment_processing_route(
)


@app.post("/account/")
@st.route
@st.route(app, "/account/")
def account_route(request: Request):
with account_page_wrapper(request, AccountTabs.billing):
billing_tab(request)
Expand All @@ -93,8 +91,7 @@ def account_route(request: Request):
)


@app.post("/account/profile/")
@st.route
@st.route(app, "/account/profile/")
def profile_route(request: Request):
with account_page_wrapper(request, AccountTabs.profile):
profile_tab(request)
Expand All @@ -110,8 +107,7 @@ def profile_route(request: Request):
)


@app.post("/saved/")
@st.route
@st.route(app, "/saved/")
def saved_route(request: Request):
with account_page_wrapper(request, AccountTabs.saved):
all_saved_runs_tab(request)
Expand All @@ -127,8 +123,7 @@ def saved_route(request: Request):
)


@app.post("/account/api-keys/")
@st.route
@st.route(app, "/account/api-keys/")
def api_keys_route(request: Request):
with account_page_wrapper(request, AccountTabs.api_keys):
api_keys_tab(request)
Expand Down
128 changes: 65 additions & 63 deletions routers/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,6 @@ async def favicon():
return FileResponse("static/favicon.ico")


@app.post("/handleError/")
@st.route
def handle_error(request: Request, json_data: dict):
context = {"request": request, "settings": settings}
match json_data["status"]:
case 404:
template = "errors/404.html"
case _:
template = "errors/unknown.html"
with page_wrapper(request):
st.html(templates.get_template(template).render(**context))


@app.get("/login/")
def login(request: Request):
if request.user and not request.user.is_anonymous:
Expand Down Expand Up @@ -208,8 +195,7 @@ def file_upload(form_data: FormData = fastapi_request_form):
return {"url": upload_file_from_bytes(filename, data, content_type)}


@app.post("/GuiComponents/")
@st.route
@st.route(app, "/GuiComponents/")
def component_page(request: Request):
import components_doc

Expand All @@ -225,8 +211,7 @@ def component_page(request: Request):
}


@app.post("/explore/")
@st.route
@st.route(app, "/explore/")
def explore_page(request: Request):
import explore

Expand All @@ -242,8 +227,7 @@ def explore_page(request: Request):
}


@app.post("/api/")
@st.route
@st.route(app, "/api/")
def api_docs_page(request: Request):
with page_wrapper(request):
_api_docs_page(request)
Expand Down Expand Up @@ -347,50 +331,60 @@ def _api_docs_page(request):
manage_api_keys(page.request.user)


@app.post("/{page_slug}/examples/")
@app.post("/{page_slug}/{run_slug}/examples/")
@app.post("/{page_slug}/{run_slug}-{example_id}/examples/")
@st.route
@st.route(
app,
"/{page_slug}/examples/",
"/{page_slug}/{run_slug}/examples/",
"/{page_slug}/{run_slug}-{example_id}/examples/",
)
def examples_route(
request: Request, page_slug: str, run_slug: str = None, example_id: str = None
):
return render_page(request, page_slug, RecipeTabs.examples, example_id)


@app.post("/{page_slug}/api/")
@app.post("/{page_slug}/{run_slug}/api/")
@app.post("/{page_slug}/{run_slug}-{example_id}/api/")
@st.route
@st.route(
app,
"/{page_slug}/api/",
"/{page_slug}/{run_slug}/api/",
"/{page_slug}/{run_slug}-{example_id}/api/",
)
def api_route(
request: Request, page_slug: str, run_slug: str = None, example_id: str = None
):
return render_page(request, page_slug, RecipeTabs.run_as_api, example_id)


@app.post("/{page_slug}/history/")
@app.post("/{page_slug}/{run_slug}/history/")
@app.post("/{page_slug}/{run_slug}-{example_id}/history/")
@st.route
@st.route(
app,
"/{page_slug}/history/",
"/{page_slug}/{run_slug}/history/",
"/{page_slug}/{run_slug}-{example_id}/history/",
)
def history_route(
request: Request, page_slug: str, run_slug: str = None, example_id: str = None
):
return render_page(request, page_slug, RecipeTabs.history, example_id)


@app.post("/{page_slug}/saved/")
@app.post("/{page_slug}/{run_slug}/saved/")
@app.post("/{page_slug}/{run_slug}-{example_id}/saved/")
@st.route
@st.route(
app,
"/{page_slug}/saved/",
"/{page_slug}/{run_slug}/saved/",
"/{page_slug}/{run_slug}-{example_id}/saved/",
)
def save_route(
request: Request, page_slug: str, run_slug: str = None, example_id: str = None
):
return render_page(request, page_slug, RecipeTabs.saved, example_id)


@app.post("/{page_slug}/integrations/add/")
@app.post("/{page_slug}/{run_slug}/integrations/add/")
@app.post("/{page_slug}/{run_slug}-{example_id}/integrations/add/")
@st.route
@st.route(
app,
"/{page_slug}/integrations/add/",
"/{page_slug}/{run_slug}/integrations/add/",
"/{page_slug}/{run_slug}-{example_id}/integrations/add/",
)
def add_integrations_route(
request: Request,
page_slug: str,
Expand All @@ -401,10 +395,12 @@ def add_integrations_route(
return render_page(request, page_slug, RecipeTabs.integrations, example_id)


@app.post("/{page_slug}/integrations/{integration_id}/stats/")
@app.post("/{page_slug}/{run_slug}/integrations/{integration_id}/stats/")
@app.post("/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/stats/")
@st.route
@st.route(
app,
"/{page_slug}/integrations/{integration_id}/stats/",
"/{page_slug}/{run_slug}/integrations/{integration_id}/stats/",
"/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/stats/",
)
def integrations_stats_route(
request: Request,
page_slug: str,
Expand All @@ -421,12 +417,12 @@ def integrations_stats_route(
return render_page(request, "stats", RecipeTabs.integrations, example_id)


@app.post("/{page_slug}/integrations/{integration_id}/analysis/")
@app.post("/{page_slug}/{run_slug}/integrations/{integration_id}/analysis/")
@app.post(
"/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/analysis/"
@st.route(
app,
"/{page_slug}/integrations/{integration_id}/analysis/",
"/{page_slug}/{run_slug}/integrations/{integration_id}/analysis/",
"/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/analysis/",
)
@st.route
def integrations_analysis_route(
request: Request,
page_slug: str,
Expand Down Expand Up @@ -457,14 +453,16 @@ def integrations_analysis_route(
)


@app.post("/{page_slug}/integrations/")
@app.post("/{page_slug}/{run_slug}/integrations/")
@app.post("/{page_slug}/{run_slug}-{example_id}/integrations/")
###
@app.post("/{page_slug}/integrations/{integration_id}/")
@app.post("/{page_slug}/{run_slug}/integrations/{integration_id}/")
@app.post("/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/")
@st.route
@st.route(
app,
"/{page_slug}/integrations/",
"/{page_slug}/{run_slug}/integrations/",
"/{page_slug}/{run_slug}-{example_id}/integrations/",
###
"/{page_slug}/integrations/{integration_id}/",
"/{page_slug}/{run_slug}/integrations/{integration_id}/",
"/{page_slug}/{run_slug}-{example_id}/integrations/{integration_id}/",
)
def integrations_route(
request: Request,
page_slug: str,
Expand All @@ -482,9 +480,11 @@ def integrations_route(
return render_page(request, page_slug, RecipeTabs.integrations, example_id)


@app.post("/chat/")
@app.post("/chats/")
@st.route
@st.route(
app,
"/chat/",
"/chats/",
)
def chat_explore_route(request: Request):
from daras_ai_v2 import chat_explore

Expand Down Expand Up @@ -572,10 +572,12 @@ def chat_lib_route(request: Request, integration_id: str, integration_name: str
)


@app.post("/{page_slug}/")
@app.post("/{page_slug}/{run_slug}/")
@app.post("/{page_slug}/{run_slug}-{example_id}/")
@st.route
@st.route(
app,
"/{page_slug}/",
"/{page_slug}/{run_slug}/",
"/{page_slug}/{run_slug}-{example_id}/",
)
def recipe_page_or_handle(
request: Request, page_slug: str, run_slug: str = None, example_id: str = None
):
Expand Down
25 changes: 12 additions & 13 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,20 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
@app.exception_handler(HTTP_404_NOT_FOUND)
@app.exception_handler(HTTP_405_METHOD_NOT_ALLOWED)
async def not_found_exception_handler(request: Request, exc: HTTPException):
if not request.headers.get("accept", "").startswith("text/html"):
return await http_exception_handler(request, exc)
return templates.TemplateResponse(
"errors/404.html",
{"request": request, "settings": settings},
status_code=exc.status_code,
)
return await _exc_handler(request, exc, "errors/404.html")


@app.exception_handler(HTTPException)
async def server_error_exception_handler(request: Request, exc: HTTPException):
if not request.headers.get("accept", "").startswith("text/html"):
return await _exc_handler(request, exc, "errors/unknown.html")


async def _exc_handler(request: Request, exc: HTTPException, template_name: str):
if request.headers.get("accept", "").startswith("text/html"):
return templates.TemplateResponse(
template_name,
context=dict(request=request, settings=settings),
status_code=exc.status_code,
)
else:
return await http_exception_handler(request, exc)
return templates.TemplateResponse(
"errors/unknown.html",
{"request": request, "settings": settings},
status_code=exc.status_code,
)
Loading

0 comments on commit 4652096

Please sign in to comment.