From 88dd5b540a9417d6c4b9012e975d36bd29eebc2e Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:16:45 +0530 Subject: [PATCH] Add method-name and group-name to openapi schema --- daras_ai_v2/base.py | 7 +++++++ routers/api.py | 10 +++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index a90f2504b..43bf49a08 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -303,6 +303,13 @@ def sentry_event_set_user(self, event, hint): } return event + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return { + "x-fern-sdk-group-name": cls.slug_versions[-1].title().replace("-", ""), + "x-fern-sdk-method-name": "status", + } + def refresh_state(self): _, run_id, uid = extract_query_params(gui.get_query_params()) channel = self.realtime_channel_name(run_id, uid) diff --git a/routers/api.py b/routers/api.py index 94011f07c..65484cd5b 100644 --- a/routers/api.py +++ b/routers/api.py @@ -5,7 +5,7 @@ import typing from types import SimpleNamespace -from fastapi import APIRouter +from fastapi import APIRouter, Query from fastapi import Depends from fastapi import Form from fastapi import HTTPException @@ -215,6 +215,7 @@ def run_api_form( name=page_cls.title + " (v3 async)", tags=[page_cls.title], status_code=202, + openapi_extra=page_cls.get_openapi_extra(), ) @app.post( os.path.join(endpoint, "async"), @@ -227,7 +228,7 @@ def run_api_json_async( request: Request, response: Response, page_request: request_model, - example_id: str | None = None, + example_id: str | None = Query(default=None), user: AppUser = Depends(api_auth_header), ): ret = _run_api( @@ -286,7 +287,10 @@ def run_api_form_async( operation_id="status__" + page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v3 status)", - openapi_extra={"x-fern-sdk-return-value": "output"}, + openapi_extra={ + "x-fern-sdk-return-value": "output", + **page_cls.get_openapi_extra(), + }, ) @app.get( os.path.join(endpoint, "status"),