From edf9aebff2691203bb5d83cbfaaf06ba9eceec03 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Wed, 8 Jan 2025 14:49:42 -0800 Subject: [PATCH] [BUG]: fix bad OpenAPI generation --- chromadb/server/fastapi/__init__.py | 91 +++++++++++++++++++---------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 4f8aeca38a8..40508d019ba 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -16,6 +16,7 @@ CapacityLimiter, ) from fastapi import FastAPI as _FastAPI, Response, Request +from fastapi.openapi.utils import get_openapi from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse from fastapi.routing import APIRoute @@ -142,18 +143,6 @@ def validate_model(model: Type[D], data: Any) -> D: # type: ignore return model.parse_obj(data) # pydantic 1.x -def get_openapi_extras_for_model(request_model: Type[D]) -> Dict[str, Any]: - openapi_extra = { - "requestBody": { - "content": { - "application/json": {"schema": request_model.model_json_schema()} - }, - "required": True, - } - } - return openapi_extra - - class ChromaAPIRouter(fastapi.APIRouter): # type: ignore # A simple subclass of fastapi's APIRouter which treats URLs with a # trailing "/" the same as URLs without. Docs will only contain URLs @@ -189,6 +178,10 @@ def __init__(self, settings: Settings): self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse) self._system = System(settings) self._api: ServerAPI = self._system.instance(ServerAPI) + + self._extra_openapi_schemas: Dict[str, Any] = {} + self._app.openapi = self.generate_openapi + self._opentelemetry_client = self._api.require(OpenTelemetryClient) self._capacity_limiter = CapacityLimiter( settings.chroma_server_thread_pool_size @@ -232,6 +225,37 @@ def __init__(self, settings: Settings): telemetry_client = self._system.instance(ProductTelemetryClient) telemetry_client.capture(ServerStartEvent()) + def generate_openapi(self) -> Dict[str, Any]: + """Used instead of the default openapi() generation handler to include manually-populated schemas.""" + schema: Dict[str, Any] = get_openapi( + title="Chroma", + routes=self._app.routes, + version="0.1.0", # todo + ) + + for key, value in self._extra_openapi_schemas.items(): + schema["components"]["schemas"][key] = value + + return schema + + def get_openapi_extras_for_body_model( + self, request_model: Type[D] + ) -> Dict[str, Any]: + schema = request_model.model_json_schema( + ref_template="#/components/schemas/{model}" + ) + if "$defs" in schema: + for key, value in schema["$defs"].items(): + self._extra_openapi_schemas[key] = value + + openapi_extra = { + "requestBody": { + "content": {"application/json": {"schema": schema}}, + "required": True, + } + } + return openapi_extra + def setup_v2_routes(self) -> None: self.router.add_api_route("/api/v2", self.root, methods=["GET"]) self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"]) @@ -253,7 +277,7 @@ def setup_v2_routes(self) -> None: self.create_database, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateDatabase), + openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase), ) self.router.add_api_route( @@ -268,7 +292,7 @@ def setup_v2_routes(self) -> None: self.create_tenant, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateTenant), + openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant), ) self.router.add_api_route( @@ -295,7 +319,7 @@ def setup_v2_routes(self) -> None: self.create_collection, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection), ) self.router.add_api_route( @@ -304,35 +328,35 @@ def setup_v2_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update", self.update, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert", self.upsert, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get", self.get, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(GetEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete", self.delete, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count", @@ -345,7 +369,9 @@ def setup_v2_routes(self) -> None: self.get_nearest_neighbors, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(request_model=QueryEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model( + request_model=QueryEmbedding + ), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -358,7 +384,7 @@ def setup_v2_routes(self) -> None: self.update_collection, methods=["PUT"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection), ) self.router.add_api_route( "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}", @@ -1138,7 +1164,7 @@ def setup_v1_routes(self) -> None: self.create_database_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateDatabase), + openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase), ) self.router.add_api_route( @@ -1153,7 +1179,7 @@ def setup_v1_routes(self) -> None: self.create_tenant_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateTenant), + openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant), ) self.router.add_api_route( @@ -1180,7 +1206,7 @@ def setup_v1_routes(self) -> None: self.create_collection_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(CreateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection), ) self.router.add_api_route( @@ -1189,35 +1215,35 @@ def setup_v1_routes(self) -> None: methods=["POST"], status_code=status.HTTP_201_CREATED, response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/update", self.update_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/upsert", self.upsert_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(AddEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/get", self.get_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(GetEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/delete", self.delete_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(DeleteEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_id}/count", @@ -1230,7 +1256,7 @@ def setup_v1_routes(self) -> None: self.get_nearest_neighbors_v1, methods=["POST"], response_model=None, - openapi_extra=get_openapi_extras_for_model(QueryEmbedding), + openapi_extra=self.get_openapi_extras_for_body_model(QueryEmbedding), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1243,7 +1269,7 @@ def setup_v1_routes(self) -> None: self.update_collection_v1, methods=["PUT"], response_model=None, - openapi_extra=get_openapi_extras_for_model(UpdateCollection), + openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection), ) self.router.add_api_route( "/api/v1/collections/{collection_name}", @@ -1598,6 +1624,7 @@ async def inner(): ), ) return api_collection_model + return await inner() @trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)