Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG]: fix bad OpenAPI generation #3445

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 60 additions & 32 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CapacityLimiter,
)
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute
Expand Down Expand Up @@ -61,6 +62,7 @@
)
from starlette.datastructures import Headers
import logging
import importlib.metadata

from chromadb.telemetry.product.events import ServerStartEvent
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
Expand Down Expand Up @@ -142,18 +144,6 @@ def validate_model(model: Type[D], data: Any) -> D: # type: ignore
return model.parse_obj(data) # pydantic 1.x


def get_openapi_extras_for_model(request_model: Type[D]) -> Dict[str, Any]:
openapi_extra = {
"requestBody": {
"content": {
"application/json": {"schema": request_model.model_json_schema()}
},
"required": True,
}
}
return openapi_extra


class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
# A simple subclass of fastapi's APIRouter which treats URLs with a
# trailing "/" the same as URLs without. Docs will only contain URLs
Expand Down Expand Up @@ -189,6 +179,10 @@ def __init__(self, settings: Settings):
self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
self._system = System(settings)
self._api: ServerAPI = self._system.instance(ServerAPI)

self._extra_openapi_schemas: Dict[str, Any] = {}
self._app.openapi = self.generate_openapi

self._opentelemetry_client = self._api.require(OpenTelemetryClient)
self._capacity_limiter = CapacityLimiter(
settings.chroma_server_thread_pool_size
Expand Down Expand Up @@ -232,6 +226,37 @@ def __init__(self, settings: Settings):
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ServerStartEvent())

def generate_openapi(self) -> Dict[str, Any]:
"""Used instead of the default openapi() generation handler to include manually-populated schemas."""
schema: Dict[str, Any] = get_openapi(
title="Chroma",
routes=self._app.routes,
version=importlib.metadata.version("chromadb"),
)

for key, value in self._extra_openapi_schemas.items():
schema["components"]["schemas"][key] = value

return schema

def get_openapi_extras_for_body_model(
self, request_model: Type[D]
) -> Dict[str, Any]:
schema = request_model.model_json_schema(
ref_template="#/components/schemas/{model}"
)
if "$defs" in schema:
for key, value in schema["$defs"].items():
self._extra_openapi_schemas[key] = value

openapi_extra = {
"requestBody": {
"content": {"application/json": {"schema": schema}},
"required": True,
}
}
return openapi_extra

def setup_v2_routes(self) -> None:
self.router.add_api_route("/api/v2", self.root, methods=["GET"])
self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"])
Expand All @@ -253,7 +278,7 @@ def setup_v2_routes(self) -> None:
self.create_database,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateDatabase),
openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
)

self.router.add_api_route(
Expand All @@ -268,7 +293,7 @@ def setup_v2_routes(self) -> None:
self.create_tenant,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateTenant),
openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
)

self.router.add_api_route(
Expand All @@ -295,7 +320,7 @@ def setup_v2_routes(self) -> None:
self.create_collection,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
)

self.router.add_api_route(
Expand All @@ -304,35 +329,35 @@ def setup_v2_routes(self) -> None:
methods=["POST"],
status_code=status.HTTP_201_CREATED,
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update",
self.update,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert",
self.upsert,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get",
self.get,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(GetEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete",
self.delete,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(DeleteEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count",
Expand All @@ -345,7 +370,9 @@ def setup_v2_routes(self) -> None:
self.get_nearest_neighbors,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(request_model=QueryEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(
request_model=QueryEmbedding
),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
Expand All @@ -358,7 +385,7 @@ def setup_v2_routes(self) -> None:
self.update_collection,
methods=["PUT"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
)
self.router.add_api_route(
"/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
Expand Down Expand Up @@ -1138,7 +1165,7 @@ def setup_v1_routes(self) -> None:
self.create_database_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateDatabase),
openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
)

self.router.add_api_route(
Expand All @@ -1153,7 +1180,7 @@ def setup_v1_routes(self) -> None:
self.create_tenant_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateTenant),
openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
)

self.router.add_api_route(
Expand All @@ -1180,7 +1207,7 @@ def setup_v1_routes(self) -> None:
self.create_collection_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(CreateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
)

self.router.add_api_route(
Expand All @@ -1189,35 +1216,35 @@ def setup_v1_routes(self) -> None:
methods=["POST"],
status_code=status.HTTP_201_CREATED,
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/update",
self.update_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/upsert",
self.upsert_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(AddEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/get",
self.get_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(GetEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/delete",
self.delete_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(DeleteEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_id}/count",
Expand All @@ -1230,7 +1257,7 @@ def setup_v1_routes(self) -> None:
self.get_nearest_neighbors_v1,
methods=["POST"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(QueryEmbedding),
openapi_extra=self.get_openapi_extras_for_body_model(QueryEmbedding),
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}",
Expand All @@ -1243,7 +1270,7 @@ def setup_v1_routes(self) -> None:
self.update_collection_v1,
methods=["PUT"],
response_model=None,
openapi_extra=get_openapi_extras_for_model(UpdateCollection),
openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}",
Expand Down Expand Up @@ -1598,6 +1625,7 @@ async def inner():
),
)
return api_collection_model

return await inner()

@trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)
Expand Down
Loading