From 1828a9ce1a3dd637a752de0aaff305af77f08d87 Mon Sep 17 00:00:00 2001 From: Robert Escriva Date: Tue, 14 Jan 2025 08:47:43 -0800 Subject: [PATCH] [BUG] Make auth requests in async context async. (#3477) The auth methods were both sync, which led to latency on the async thread pool. Fix that. --- chromadb/server/fastapi/__init__.py | 100 +++++++++++++++++----------- 1 file changed, 61 insertions(+), 39 deletions(-) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 0cede0cddc7..7e1bde7683d 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -449,7 +449,17 @@ def _set_request_context(self, request: Request) -> None: "auth_request", OpenTelemetryGranularity.OPERATION, ) - def auth_request( + async def auth_request( + self, + headers: Headers, + action: AuthzAction, + tenant: Optional[str], + database: Optional[str], + collection: Optional[str], + ) -> None: + return await to_thread.run_sync(self.sync_auth_request, *(headers, action, tenant, database, collection)) + + def sync_auth_request( self, headers: Headers, action: AuthzAction, @@ -521,7 +531,7 @@ def process_create_database( ) -> None: db = validate_model(CreateDatabase, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( headers, AuthzAction.CREATE_DATABASE, tenant, @@ -548,7 +558,7 @@ async def get_database( database_name: str, tenant: str, ) -> Database: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.GET_DATABASE, tenant, @@ -596,7 +606,7 @@ async def create_tenant( def process_create_tenant(request: Request, raw_body: bytes) -> None: tenant = validate_model(CreateTenant, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.CREATE_TENANT, tenant.name, @@ -604,6 +614,7 @@ def process_create_tenant(request: Request, raw_body: bytes) -> None: None, ) + return self._api.create_tenant(tenant.name) await to_thread.run_sync( @@ -619,7 +630,7 @@ async def get_tenant( request: Request, tenant: str, ) -> Tenant: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.GET_TENANT, tenant, @@ -644,7 +655,7 @@ async def list_databases( limit: Optional[int] = None, offset: Optional[int] = None, ) -> Sequence[Database]: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.LIST_DATABASES, tenant, @@ -675,7 +686,7 @@ async def list_collections( def process_list_collections( limit: Optional[int], offset: Optional[int], tenant: str, database_name: str ) -> Sequence[CollectionModel]: - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.LIST_COLLECTIONS, tenant, @@ -711,7 +722,7 @@ async def count_collections( tenant: str, database_name: str, ) -> int: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.COUNT_COLLECTIONS, tenant, @@ -748,7 +759,7 @@ def process_create_collection( else CollectionConfigurationInternal.from_json(create.configuration) ) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.CREATE_COLLECTION, tenant, @@ -790,7 +801,7 @@ async def get_collection( database_name: str, collection_name: str, ) -> CollectionModel: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.GET_COLLECTION, tenant, @@ -824,7 +835,7 @@ def process_update_collection( request: Request, collection_id: str, raw_body: bytes ) -> None: update = validate_model(UpdateCollection, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.UPDATE_COLLECTION, tenant, @@ -857,7 +868,7 @@ async def delete_collection( tenant: str, database_name: str, ) -> None: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.DELETE_COLLECTION, tenant, @@ -886,7 +897,7 @@ async def add( def process_add(request: Request, raw_body: bytes) -> bool: add = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.ADD, tenant, @@ -934,7 +945,7 @@ async def update( def process_update(request: Request, raw_body: bytes) -> bool: update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.UPDATE, tenant, @@ -975,7 +986,7 @@ async def upsert( def process_upsert(request: Request, raw_body: bytes) -> bool: upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.UPSERT, tenant, @@ -1018,7 +1029,7 @@ async def get( ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: get = validate_model(GetEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.GET, tenant, @@ -1068,7 +1079,7 @@ async def delete( ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.DELETE, tenant, @@ -1101,7 +1112,7 @@ async def count( database_name: str, collection_id: str, ) -> int: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.COUNT, tenant, @@ -1126,7 +1137,7 @@ async def reset( self, request: Request, ) -> bool: - self.auth_request( + await self.auth_request( request.headers, AuthzAction.RESET, None, @@ -1153,7 +1164,7 @@ async def get_nearest_neighbors( def process_query(request: Request, raw_body: bytes) -> QueryResult: query = validate_model(QueryEmbedding, orjson.loads(raw_body)) - self.auth_request( + self.sync_auth_request( request.headers, AuthzAction.QUERY, tenant, @@ -1346,7 +1357,7 @@ def setup_v1_routes(self) -> None: "auth_and_get_tenant_and_database_for_request_v1", OpenTelemetryGranularity.OPERATION, ) - def auth_and_get_tenant_and_database_for_request( + async def auth_and_get_tenant_and_database_for_request( self, headers: Headers, action: AuthzAction, @@ -1368,6 +1379,17 @@ def auth_and_get_tenant_and_database_for_request( (can be overwritten separately) - The user has access to a single tenant and/or single database. """ + return await to_thread.run_sync(self.auth_and_get_tenant_and_database_for_request, headers, action, tenant, database, collection) + + def sync_auth_and_get_tenant_and_database_for_request( + self, + headers: Headers, + action: AuthzAction, + tenant: Optional[str], + database: Optional[str], + collection: Optional[str], + ) -> Tuple[Optional[str], Optional[str]]: + if not self.authn_provider: add_attributes_to_current_span( { @@ -1423,7 +1445,7 @@ def process_create_database( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = self.sync_auth_and_get_tenant_and_database_for_request( headers, AuthzAction.CREATE_DATABASE, tenant, @@ -1455,7 +1477,7 @@ async def get_database_v1( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.GET_DATABASE, tenant, @@ -1485,7 +1507,7 @@ async def create_tenant_v1( def process_create_tenant(request: Request, raw_body: bytes) -> None: tenant = validate_model(CreateTenant, orjson.loads(raw_body)) - maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request( + maybe_tenant, _ = self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.CREATE_TENANT, tenant.name, @@ -1510,7 +1532,7 @@ async def get_tenant_v1( request: Request, tenant: str, ) -> Tenant: - maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request( + maybe_tenant, _ = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.GET_TENANT, tenant, @@ -1541,7 +1563,7 @@ async def list_collections_v1( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.LIST_COLLECTIONS, tenant, @@ -1577,7 +1599,7 @@ async def count_collections_v1( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.COUNT_COLLECTIONS, tenant, @@ -1619,7 +1641,7 @@ def process_create_collection( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.CREATE_COLLECTION, tenant, @@ -1664,7 +1686,7 @@ async def get_collection_v1( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.GET_COLLECTION, tenant, @@ -1701,7 +1723,7 @@ def process_update_collection( request: Request, collection_id: str, raw_body: bytes ) -> None: update = validate_model(UpdateCollection, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.UPDATE_COLLECTION, None, @@ -1733,7 +1755,7 @@ async def delete_collection_v1( ( maybe_tenant, maybe_database, - ) = self.auth_and_get_tenant_and_database_for_request( + ) = await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.DELETE_COLLECTION, tenant, @@ -1763,7 +1785,7 @@ async def add_v1( def process_add(request: Request, raw_body: bytes) -> bool: add = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.ADD, None, @@ -1805,7 +1827,7 @@ async def update_v1( def process_update(request: Request, raw_body: bytes) -> bool: update = validate_model(UpdateEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.UPDATE, None, @@ -1840,7 +1862,7 @@ async def upsert_v1( def process_upsert(request: Request, raw_body: bytes) -> bool: upsert = validate_model(AddEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.UPSERT, None, @@ -1877,7 +1899,7 @@ async def get_v1( ) -> GetResult: def process_get(request: Request, raw_body: bytes) -> GetResult: get = validate_model(GetEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.GET, None, @@ -1921,7 +1943,7 @@ async def delete_v1( ) -> None: def process_delete(request: Request, raw_body: bytes) -> None: delete = validate_model(DeleteEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.DELETE, None, @@ -1948,7 +1970,7 @@ async def count_v1( request: Request, collection_id: str, ) -> int: - self.auth_and_get_tenant_and_database_for_request( + await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.COUNT, None, @@ -1970,7 +1992,7 @@ async def reset_v1( self, request: Request, ) -> bool: - self.auth_and_get_tenant_and_database_for_request( + await self.auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.RESET, None, @@ -1997,7 +2019,7 @@ async def get_nearest_neighbors_v1( def process_query(request: Request, raw_body: bytes) -> QueryResult: query = validate_model(QueryEmbedding, orjson.loads(raw_body)) - self.auth_and_get_tenant_and_database_for_request( + self.sync_auth_and_get_tenant_and_database_for_request( request.headers, AuthzAction.QUERY, None,