Skip to content

Commit

Permalink
[BUG] Make auth requests in async context async. (#3477)
Browse files Browse the repository at this point in the history
The auth methods were both sync, which led to latency on the async
thread pool.  Fix that.
  • Loading branch information
rescrv authored Jan 14, 2025
1 parent 66e8b30 commit 1828a9c
Showing 1 changed file with 61 additions and 39 deletions.
100 changes: 61 additions & 39 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,17 @@ def _set_request_context(self, request: Request) -> None:
"auth_request",
OpenTelemetryGranularity.OPERATION,
)
def auth_request(
async def auth_request(
self,
headers: Headers,
action: AuthzAction,
tenant: Optional[str],
database: Optional[str],
collection: Optional[str],
) -> None:
return await to_thread.run_sync(self.sync_auth_request, *(headers, action, tenant, database, collection))

def sync_auth_request(
self,
headers: Headers,
action: AuthzAction,
Expand Down Expand Up @@ -521,7 +531,7 @@ def process_create_database(
) -> None:
db = validate_model(CreateDatabase, orjson.loads(raw_body))

self.auth_request(
self.sync_auth_request(
headers,
AuthzAction.CREATE_DATABASE,
tenant,
Expand All @@ -548,7 +558,7 @@ async def get_database(
database_name: str,
tenant: str,
) -> Database:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.GET_DATABASE,
tenant,
Expand Down Expand Up @@ -596,14 +606,15 @@ async def create_tenant(
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = validate_model(CreateTenant, orjson.loads(raw_body))

self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.CREATE_TENANT,
tenant.name,
None,
None,
)


return self._api.create_tenant(tenant.name)

await to_thread.run_sync(
Expand All @@ -619,7 +630,7 @@ async def get_tenant(
request: Request,
tenant: str,
) -> Tenant:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.GET_TENANT,
tenant,
Expand All @@ -644,7 +655,7 @@ async def list_databases(
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Database]:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.LIST_DATABASES,
tenant,
Expand Down Expand Up @@ -675,7 +686,7 @@ async def list_collections(
def process_list_collections(
limit: Optional[int], offset: Optional[int], tenant: str, database_name: str
) -> Sequence[CollectionModel]:
self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.LIST_COLLECTIONS,
tenant,
Expand Down Expand Up @@ -711,7 +722,7 @@ async def count_collections(
tenant: str,
database_name: str,
) -> int:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.COUNT_COLLECTIONS,
tenant,
Expand Down Expand Up @@ -748,7 +759,7 @@ def process_create_collection(
else CollectionConfigurationInternal.from_json(create.configuration)
)

self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.CREATE_COLLECTION,
tenant,
Expand Down Expand Up @@ -790,7 +801,7 @@ async def get_collection(
database_name: str,
collection_name: str,
) -> CollectionModel:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.GET_COLLECTION,
tenant,
Expand Down Expand Up @@ -824,7 +835,7 @@ def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = validate_model(UpdateCollection, orjson.loads(raw_body))
self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.UPDATE_COLLECTION,
tenant,
Expand Down Expand Up @@ -857,7 +868,7 @@ async def delete_collection(
tenant: str,
database_name: str,
) -> None:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.DELETE_COLLECTION,
tenant,
Expand Down Expand Up @@ -886,7 +897,7 @@ async def add(

def process_add(request: Request, raw_body: bytes) -> bool:
add = validate_model(AddEmbedding, orjson.loads(raw_body))
self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.ADD,
tenant,
Expand Down Expand Up @@ -934,7 +945,7 @@ async def update(
def process_update(request: Request, raw_body: bytes) -> bool:
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))

self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.UPDATE,
tenant,
Expand Down Expand Up @@ -975,7 +986,7 @@ async def upsert(
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))

self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.UPSERT,
tenant,
Expand Down Expand Up @@ -1018,7 +1029,7 @@ async def get(
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = validate_model(GetEmbedding, orjson.loads(raw_body))
self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.GET,
tenant,
Expand Down Expand Up @@ -1068,7 +1079,7 @@ async def delete(
) -> None:
def process_delete(request: Request, raw_body: bytes) -> None:
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.DELETE,
tenant,
Expand Down Expand Up @@ -1101,7 +1112,7 @@ async def count(
database_name: str,
collection_id: str,
) -> int:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.COUNT,
tenant,
Expand All @@ -1126,7 +1137,7 @@ async def reset(
self,
request: Request,
) -> bool:
self.auth_request(
await self.auth_request(
request.headers,
AuthzAction.RESET,
None,
Expand All @@ -1153,7 +1164,7 @@ async def get_nearest_neighbors(
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = validate_model(QueryEmbedding, orjson.loads(raw_body))

self.auth_request(
self.sync_auth_request(
request.headers,
AuthzAction.QUERY,
tenant,
Expand Down Expand Up @@ -1346,7 +1357,7 @@ def setup_v1_routes(self) -> None:
"auth_and_get_tenant_and_database_for_request_v1",
OpenTelemetryGranularity.OPERATION,
)
def auth_and_get_tenant_and_database_for_request(
async def auth_and_get_tenant_and_database_for_request(
self,
headers: Headers,
action: AuthzAction,
Expand All @@ -1368,6 +1379,17 @@ def auth_and_get_tenant_and_database_for_request(
(can be overwritten separately)
- The user has access to a single tenant and/or single database.
"""
return await to_thread.run_sync(self.auth_and_get_tenant_and_database_for_request, headers, action, tenant, database, collection)

def sync_auth_and_get_tenant_and_database_for_request(
self,
headers: Headers,
action: AuthzAction,
tenant: Optional[str],
database: Optional[str],
collection: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:

if not self.authn_provider:
add_attributes_to_current_span(
{
Expand Down Expand Up @@ -1423,7 +1445,7 @@ def process_create_database(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = self.sync_auth_and_get_tenant_and_database_for_request(
headers,
AuthzAction.CREATE_DATABASE,
tenant,
Expand Down Expand Up @@ -1455,7 +1477,7 @@ async def get_database_v1(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET_DATABASE,
tenant,
Expand Down Expand Up @@ -1485,7 +1507,7 @@ async def create_tenant_v1(
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = validate_model(CreateTenant, orjson.loads(raw_body))

maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request(
maybe_tenant, _ = self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.CREATE_TENANT,
tenant.name,
Expand All @@ -1510,7 +1532,7 @@ async def get_tenant_v1(
request: Request,
tenant: str,
) -> Tenant:
maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request(
maybe_tenant, _ = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET_TENANT,
tenant,
Expand Down Expand Up @@ -1541,7 +1563,7 @@ async def list_collections_v1(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.LIST_COLLECTIONS,
tenant,
Expand Down Expand Up @@ -1577,7 +1599,7 @@ async def count_collections_v1(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.COUNT_COLLECTIONS,
tenant,
Expand Down Expand Up @@ -1619,7 +1641,7 @@ def process_create_collection(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.CREATE_COLLECTION,
tenant,
Expand Down Expand Up @@ -1664,7 +1686,7 @@ async def get_collection_v1(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET_COLLECTION,
tenant,
Expand Down Expand Up @@ -1701,7 +1723,7 @@ def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = validate_model(UpdateCollection, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.UPDATE_COLLECTION,
None,
Expand Down Expand Up @@ -1733,7 +1755,7 @@ async def delete_collection_v1(
(
maybe_tenant,
maybe_database,
) = self.auth_and_get_tenant_and_database_for_request(
) = await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.DELETE_COLLECTION,
tenant,
Expand Down Expand Up @@ -1763,7 +1785,7 @@ async def add_v1(

def process_add(request: Request, raw_body: bytes) -> bool:
add = validate_model(AddEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.ADD,
None,
Expand Down Expand Up @@ -1805,7 +1827,7 @@ async def update_v1(
def process_update(request: Request, raw_body: bytes) -> bool:
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.UPDATE,
None,
Expand Down Expand Up @@ -1840,7 +1862,7 @@ async def upsert_v1(
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.UPSERT,
None,
Expand Down Expand Up @@ -1877,7 +1899,7 @@ async def get_v1(
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = validate_model(GetEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET,
None,
Expand Down Expand Up @@ -1921,7 +1943,7 @@ async def delete_v1(
) -> None:
def process_delete(request: Request, raw_body: bytes) -> None:
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.DELETE,
None,
Expand All @@ -1948,7 +1970,7 @@ async def count_v1(
request: Request,
collection_id: str,
) -> int:
self.auth_and_get_tenant_and_database_for_request(
await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.COUNT,
None,
Expand All @@ -1970,7 +1992,7 @@ async def reset_v1(
self,
request: Request,
) -> bool:
self.auth_and_get_tenant_and_database_for_request(
await self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.RESET,
None,
Expand All @@ -1997,7 +2019,7 @@ async def get_nearest_neighbors_v1(
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = validate_model(QueryEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
self.sync_auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.QUERY,
None,
Expand Down

0 comments on commit 1828a9c

Please sign in to comment.