Skip to content

Commit

Permalink
[ENH] Add a simple rate limiter for async. (#3476)
Browse files Browse the repository at this point in the history
This adds a rate limiter at the level of the FastAPI async code. The
reason for doing this is that by the time a request gets to the pool,
it is already sitting in a queue. We want the rate limiter to be async
so that it can act before the request enters that queue.
  • Loading branch information
rescrv authored Jan 14, 2025
1 parent 1828a9c commit 82f769c
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 5 deletions.
3 changes: 2 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer, Action
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.rate_limit import RateLimitEnforcer, AsyncRateLimitEnforcer
from chromadb.segment import SegmentManager
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection
Expand Down Expand Up @@ -117,6 +117,7 @@ class SegmentAPI(ServerAPI):
_opentelemetry_client: OpenTelemetryClient
_tenant_id: str
_topic_ns: str
_rate_limit_enforcer: RateLimitEnforcer

def __init__(self, system: System):
super().__init__(system)
Expand Down
5 changes: 5 additions & 0 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"chromadb.ingest.Producer": "chroma_producer_impl",
"chromadb.quota.QuotaEnforcer": "chroma_quota_enforcer_impl",
"chromadb.rate_limit.RateLimitEnforcer": "chroma_rate_limit_enforcer_impl",
"chromadb.rate_limit.AsyncRateLimitEnforcer": "chroma_async_rate_limit_enforcer_impl",
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
Expand Down Expand Up @@ -257,6 +258,10 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
"chromadb.rate_limit.simple_rate_limit.SimpleRateLimitEnforcer"
)

chroma_async_rate_limit_enforcer_impl: str = (
"chromadb.rate_limit.simple_rate_limit.SimpleAsyncRateLimitEnforcer"
)

# ==========
# gRPC service config
# ==========
Expand Down
23 changes: 20 additions & 3 deletions chromadb/rate_limit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from abc import abstractmethod
from typing import Callable, TypeVar, Any
from typing import Awaitable, Callable, TypeVar, Any
from chromadb.config import Component, System

T = TypeVar("T", bound=Callable[..., Any])
A = TypeVar("A", bound=Awaitable[Any])


class RateLimitEnforcer(Component):
"""
Rate limit enforcer. Implemented as a wrapper around server functions to
block requests if rate limits are exceeded.
Rate limit enforcer.
Implemented as a wrapper around server functions to block requests if rate limits are exceeded.
"""

def __init__(self, system: System) -> None:
Expand All @@ -17,3 +19,18 @@ def __init__(self, system: System) -> None:
@abstractmethod
def rate_limit(self, func: T) -> T:
pass


class AsyncRateLimitEnforcer(Component):
    """
    Async rate limit enforcer.

    Implemented as a wrapper around async functions to block requests if
    rate limits are exceeded. Concrete enforcers supply the wrapping logic
    by overriding ``rate_limit``.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def rate_limit(self, func: A) -> A:
        """
        Wrap ``func`` with rate limiting and return the wrapped object.

        NOTE(review): ``A`` is bound to ``Awaitable[Any]``, but the value
        passed here is called and then awaited by implementations; the bound
        presumably should describe a callable returning an awaitable --
        confirm before tightening type checks.
        """
19 changes: 18 additions & 1 deletion chromadb/rate_limit/simple_rate_limit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from overrides import override
from typing import Any, Callable, TypeVar
from typing import Any, Awaitable, Callable, TypeVar
from functools import wraps

from chromadb.rate_limit import RateLimitEnforcer, AsyncRateLimitEnforcer
from chromadb.config import System

T = TypeVar("T", bound=Callable[..., Any])
A = TypeVar("A", bound=Awaitable[Any])


class SimpleRateLimitEnforcer(RateLimitEnforcer):
Expand All @@ -23,3 +24,19 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)

return wrapper # type: ignore


class SimpleAsyncRateLimitEnforcer(AsyncRateLimitEnforcer):
    """
    A naive implementation of an async rate limit enforcer that allows all
    requests.

    Subclasses AsyncRateLimitEnforcer rather than the synchronous
    RateLimitEnforcer, since ``rate_limit`` here produces a coroutine
    function and awaits the wrapped callable.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @override
    def rate_limit(self, func: A) -> A:
        """
        Return an async wrapper around ``func`` that never limits.

        The wrapper forwards all positional and keyword arguments to
        ``func`` unchanged, awaits the result, and returns it.
        """

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # No limit is applied: every call is passed straight through.
            return await func(*args, **kwargs)

        return wrapper  # type: ignore
39 changes: 39 additions & 0 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import (
Any,
Awaitable,
Callable,
cast,
Dict,
Expand All @@ -21,6 +22,7 @@
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute
from fastapi import HTTPException, status
from functools import wraps

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
Expand Down Expand Up @@ -48,6 +50,7 @@
QuotaError,
)
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import AsyncRateLimitEnforcer
from chromadb.server import Server
from chromadb.server.fastapi.types import (
AddEmbedding,
Expand Down Expand Up @@ -82,6 +85,14 @@
logger = logging.getLogger(__name__)


def rate_limit(func):
    """
    Decorate an async method so each call goes through the instance's
    async rate limit enforcer.

    The first positional argument of every call is treated as the owning
    instance; its ``_async_rate_limit_enforcer`` attribute wraps ``func`` at
    call time, and the wrapped callable is then awaited with the original
    arguments.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # The enforcer lives on the instance, so it is looked up per call
        # rather than when the decorator is applied.
        owner = args[0]
        enforcer = owner._async_rate_limit_enforcer
        limited = enforcer.rate_limit(func)
        return await limited(*args, **kwargs)

    return wrapper


def use_route_names_as_operation_ids(app: _FastAPI) -> None:
"""
Simplify operation IDs so that generated API clients have simpler function
Expand Down Expand Up @@ -203,6 +214,7 @@ def __init__(self, settings: Settings):
self._app.add_exception_handler(
RateLimitError, self.rate_limit_exception_handler
)
self._async_rate_limit_enforcer = self._system.require(AsyncRateLimitEnforcer)

self._app.on_event("shutdown")(self.shutdown)

Expand Down Expand Up @@ -449,6 +461,7 @@ def _set_request_context(self, request: Request) -> None:
"auth_request",
OpenTelemetryGranularity.OPERATION,
)
@rate_limit
async def auth_request(
self,
headers: Headers,
Expand Down Expand Up @@ -886,6 +899,7 @@ async def delete_collection(
)

@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def add(
self,
request: Request,
Expand Down Expand Up @@ -935,6 +949,7 @@ def process_add(request: Request, raw_body: bytes) -> bool:
raise HTTPException(status_code=500, detail=str(e))

@trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def update(
self,
request: Request,
Expand Down Expand Up @@ -976,6 +991,7 @@ def process_update(request: Request, raw_body: bytes) -> bool:
)

@trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def upsert(
self,
request: Request,
Expand Down Expand Up @@ -1020,6 +1036,7 @@ def process_upsert(request: Request, raw_body: bytes) -> bool:
)

@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get(
self,
collection_id: str,
Expand Down Expand Up @@ -1070,6 +1087,7 @@ def process_get(request: Request, raw_body: bytes) -> GetResult:
return get_result

@trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete(
self,
collection_id: str,
Expand Down Expand Up @@ -1105,6 +1123,7 @@ def process_delete(request: Request, raw_body: bytes) -> None:
)

@trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count(
self,
request: Request,
Expand Down Expand Up @@ -1133,6 +1152,7 @@ async def count(
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def reset(
self,
request: Request,
Expand All @@ -1154,6 +1174,7 @@ async def reset(
)

@trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_nearest_neighbors(
self,
tenant: str,
Expand Down Expand Up @@ -1357,6 +1378,7 @@ def setup_v1_routes(self) -> None:
"auth_and_get_tenant_and_database_for_request_v1",
OpenTelemetryGranularity.OPERATION,
)
@rate_limit
async def auth_and_get_tenant_and_database_for_request(
self,
headers: Headers,
Expand Down Expand Up @@ -1432,6 +1454,7 @@ def sync_auth_and_get_tenant_and_database_for_request(
return (tenant, database)

@trace_method("FastAPI.create_database_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_database_v1(
self,
request: Request,
Expand Down Expand Up @@ -1468,6 +1491,7 @@ def process_create_database(
)

@trace_method("FastAPI.get_database_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_database_v1(
self,
request: Request,
Expand Down Expand Up @@ -1500,6 +1524,7 @@ async def get_database_v1(
)

@trace_method("FastAPI.create_tenant_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_tenant_v1(
self,
request: Request,
Expand Down Expand Up @@ -1527,6 +1552,7 @@ def process_create_tenant(request: Request, raw_body: bytes) -> None:
)

@trace_method("FastAPI.get_tenant_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_tenant_v1(
self,
request: Request,
Expand All @@ -1552,6 +1578,7 @@ async def get_tenant_v1(
)

@trace_method("FastAPI.list_collections_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def list_collections_v1(
self,
request: Request,
Expand Down Expand Up @@ -1590,6 +1617,7 @@ async def list_collections_v1(
return api_collection_models

@trace_method("FastAPI.count_collections_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count_collections_v1(
self,
request: Request,
Expand Down Expand Up @@ -1622,6 +1650,7 @@ async def count_collections_v1(
)

@trace_method("FastAPI.create_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_collection_v1(
self,
request: Request,
Expand Down Expand Up @@ -1676,6 +1705,7 @@ def process_create_collection(
return api_collection_model

@trace_method("FastAPI.get_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_collection_v1(
self,
request: Request,
Expand Down Expand Up @@ -1714,6 +1744,7 @@ async def inner():
return await inner()

@trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def update_collection_v1(
self,
collection_id: str,
Expand Down Expand Up @@ -1745,6 +1776,7 @@ def process_update_collection(
)

@trace_method("FastAPI.delete_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete_collection_v1(
self,
request: Request,
Expand Down Expand Up @@ -1776,6 +1808,7 @@ async def delete_collection_v1(
)

@trace_method("FastAPI.add_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def add_v1(
self,
request: Request,
Expand Down Expand Up @@ -1854,6 +1887,7 @@ def process_update(request: Request, raw_body: bytes) -> bool:
)

@trace_method("FastAPI.upsert_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def upsert_v1(
self,
request: Request,
Expand Down Expand Up @@ -1892,6 +1926,7 @@ def process_upsert(request: Request, raw_body: bytes) -> bool:
)

@trace_method("FastAPI.get_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_v1(
self,
collection_id: str,
Expand Down Expand Up @@ -1936,6 +1971,7 @@ def process_get(request: Request, raw_body: bytes) -> GetResult:
return get_result

@trace_method("FastAPI.delete_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete_v1(
self,
collection_id: str,
Expand Down Expand Up @@ -1965,6 +2001,7 @@ def process_delete(request: Request, raw_body: bytes) -> None:
)

@trace_method("FastAPI.count_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count_v1(
self,
request: Request,
Expand All @@ -1988,6 +2025,7 @@ async def count_v1(
)

@trace_method("FastAPI.reset_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def reset_v1(
self,
request: Request,
Expand All @@ -2011,6 +2049,7 @@ async def reset_v1(
@trace_method(
"FastAPI.get_nearest_neighbors_v1", OpenTelemetryGranularity.OPERATION
)
@rate_limit
async def get_nearest_neighbors_v1(
self,
collection_id: str,
Expand Down

0 comments on commit 82f769c

Please sign in to comment.