From eda4f8fbae0c9ec08d553ca0f83e1041e7ef1fc5 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 16:00:30 +0300 Subject: [PATCH 1/7] WIP throttling --- ninja/conf.py | 35 +++-- ninja/errors.py | 6 + ninja/main.py | 21 +++ ninja/operation.py | 71 ++++++++-- ninja/router.py | 45 ++++++- ninja/testing/client.py | 3 +- ninja/throttling.py | 226 +++++++++++++++++++++++++++++++ tests/test_throttling.py | 279 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 653 insertions(+), 33 deletions(-) create mode 100644 ninja/throttling.py create mode 100644 tests/test_throttling.py diff --git a/ninja/conf.py b/ninja/conf.py index ca378d25c..d402cdcc1 100644 --- a/ninja/conf.py +++ b/ninja/conf.py @@ -1,30 +1,27 @@ from math import inf +from typing import Dict, Optional from django.conf import settings as django_settings from pydantic import BaseModel, Field class Settings(BaseModel): - """ - Alter these by modifying the values in Django's settings module (usually - `settings.py`). - - Attributes: - NINJA_PAGINATION_CLASS (str): - The pagination class to use. Defaults to - `ninja.pagination.LimitOffsetPagination`. - NINJA_PAGINATION_PER_PAGE (int): - The default page size. Defaults to `100`. - NINJA_PAGINATION_MAX_LIMIT (int): - The maximum number of results per page. Defaults to `inf`. - """ - - PAGINATION_CLASS: str = Field( - "ninja.pagination.LimitOffsetPagination", alias="NINJA_PAGINATION_CLASS" - ) + # Pagination + PAGINATION_CLASS: str = Field("ninja.pagination.LimitOffsetPagination", alias="NINJA_PAGINATION_CLASS") PAGINATION_PER_PAGE: int = Field(100, alias="NINJA_PAGINATION_PER_PAGE") PAGINATION_MAX_LIMIT: int = Field(inf, alias="NINJA_PAGINATION_MAX_LIMIT") + # Throttling + NUM_PROXIES: Optional[int] = Field(None, alias="NINJA_NUM_PROXIES") + DEFAULT_THROTTLE_RATES: Dict[str, Optional[str]] = Field( + { + "auth": "10000/day", + "user": "10000/day", + "anon": "1000/day", + }, + alias="NINJA_DEFAULT_THROTTLE_RATES", + ) + class Config: from_attributes = True @@ -32,6 +29,4 @@ class Config: settings = Settings.model_validate(django_settings) if hasattr(django_settings, "NINJA_DOCS_VIEW"): - raise Exception( - "NINJA_DOCS_VIEW is removed. Use NinjaAPI(docs=...) instead" - ) # pragma: no cover + raise Exception("NINJA_DOCS_VIEW is removed. Use NinjaAPI(docs=...) instead") # pragma: no cover diff --git a/ninja/errors.py b/ninja/errors.py index 3c1056a7e..f938e0635 100644 --- a/ninja/errors.py +++ b/ninja/errors.py @@ -53,6 +53,12 @@ def __str__(self) -> str: return self.message +class Throttled(HttpError): + def __init__(self, wait: int) -> None: + self.wait = wait + super().__init__(status_code=429, message="Too many requests.") + + def set_default_exc_handlers(api: "NinjaAPI") -> None: api.add_exception_handler( Exception, diff --git a/ninja/main.py b/ninja/main.py index 0abd02a95..b6ec5816f 100644 --- a/ninja/main.py +++ b/ninja/main.py @@ -27,6 +27,7 @@ from ninja.parser import Parser from ninja.renderers import BaseRenderer, JSONRenderer from ninja.router import Router +from ninja.throttling import BaseThrottle from ninja.types import DictStrAny, TCallable from ninja.utils import is_debug_server, normalize_path @@ -61,6 +62,7 @@ def __init__( urls_namespace: Optional[str] = None, csrf: bool = False, auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, renderer: Optional[BaseRenderer] = None, parser: Optional[Parser] = None, default_router: Optional[Router] = None, @@ -111,6 +113,8 @@ def __init__( else: self.auth = auth + self.throttle = throttle + self._routers: List[Tuple[str, Router]] = [] self.default_router = default_router or Router() self.add_router("", self.default_router) @@ -120,6 +124,7 @@ def get( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -141,6 +146,7 @@ def get( return self.default_router.get( path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -161,6 +167,7 @@ def post( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -182,6 +189,7 @@ def post( return self.default_router.post( path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -202,6 +210,7 @@ def delete( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -223,6 +232,7 @@ def delete( return self.default_router.delete( path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -243,6 +253,7 @@ def patch( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -264,6 +275,7 @@ def patch( return self.default_router.patch( path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -284,6 +296,7 @@ def put( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -305,6 +318,7 @@ def put( return self.default_router.put( path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -326,6 +340,7 @@ def api_operation( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -344,6 +359,7 @@ def api_operation( methods, path, auth=auth is NOT_SET and self.auth or auth, + throttle=throttle is NOT_SET and self.throttle or throttle, response=response, operation_id=operation_id, summary=summary, @@ -365,6 +381,7 @@ def add_router( router: Union[Router, str], *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, tags: Optional[List[str]] = None, parent_router: Optional[Router] = None, ) -> None: @@ -374,6 +391,10 @@ def add_router( if auth is not NOT_SET: router.auth = auth + + if throttle is not NOT_SET: + router.throttle = throttle + if tags is not None: router.tags = tags diff --git a/ninja/operation.py b/ninja/operation.py index bad707831..cdb407cfb 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -18,11 +18,12 @@ from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed from django.http.response import HttpResponseBase -from ninja.constants import NOT_SET -from ninja.errors import AuthenticationError, ConfigError, ValidationError +from ninja.constants import NOT_SET, NOT_SET_TYPE +from ninja.errors import AuthenticationError, ConfigError, Throttled, ValidationError from ninja.params.models import TModels from ninja.schema import Schema from ninja.signature import ViewSignature, is_async +from ninja.throttling import BaseThrottle from ninja.types import DictStrAny from ninja.utils import check_csrf, is_async_callable @@ -39,7 +40,8 @@ def __init__( methods: List[str], view_func: Callable, *, - auth: Optional[Union[Sequence[Callable], Callable, object]] = NOT_SET, + auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -66,6 +68,17 @@ def __init__( self.auth_callbacks: Sequence[Callable] = [] self._set_auth(auth) + if isinstance(throttle, BaseThrottle): + throttle = [throttle] + self.throttle_param = throttle + self.throttle_objects: List[BaseThrottle] = [] + if throttle is not NOT_SET: + for th in throttle: + assert isinstance( + th, BaseThrottle + ), "Throttle should be an instance of BaseThrottle" + self.throttle_objects.append(th) + self.signature = ViewSignature(self.path, self.view_func) self.models: TModels = self.signature.models @@ -92,7 +105,7 @@ def __init__( self.exclude_none = exclude_none if hasattr(view_func, "_ninja_contribute_to_operation"): - # Allow 3rd party code to contribute to the operation behaviour + # Allow 3rd party code to contribute to the operation behavior callbacks: List[Callable] = view_func._ninja_contribute_to_operation for callback in callbacks: callback(self) @@ -115,12 +128,30 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: self.api = api + if self.auth_param == NOT_SET: if api.auth != NOT_SET: self._set_auth(self.api.auth) if router.auth != NOT_SET: self._set_auth(router.auth) + if self.throttle_param == NOT_SET: + if api.throttle != NOT_SET: + self.throttle_objects = ( + isinstance(api.throttle, BaseThrottle) + and [api.throttle] + or api.throttle + ) + if router.throttle != NOT_SET: + self.throttle_objects = ( + isinstance(router.throttle, BaseThrottle) + and [router.throttle] + or router.throttle + ) + assert all( + isinstance(th, BaseThrottle) for th in self.throttle_objects + ), "Throttle should be an instance of BaseThrottle" + if self.tags is None: if router.tags is not None: self.tags = router.tags @@ -132,16 +163,23 @@ def _set_auth( self.auth_callbacks = isinstance(auth, Sequence) and auth or [auth] def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: - "Runs security checks for each operation" + "Runs security/throttle checks for each operation" + + # csrf: + if self.api.csrf: + error = check_csrf(request, self.view_func) + if error: + return error + # auth: if self.auth_callbacks: error = self._run_authentication(request) if error: return error - # csrf: - if self.api.csrf: - error = check_csrf(request, self.view_func) + # Throttling: + if self.throttle_objects: + error = self._check_throttles(request) if error: return error @@ -162,6 +200,21 @@ def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]: return None return self.api.on_exception(request, AuthenticationError()) + def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: + throttle_durations = [] + for throttle in self.throttle_objects: + if not throttle.allow_request(request): + throttle_durations.append(throttle.wait()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + durations = [ + duration for duration in throttle_durations if duration is not None + ] + + duration = max(durations, default=None) + return self.api.on_exception(request, Throttled(wait=duration)) + def _result_to_response( self, request: HttpRequest, result: Any, temporal_response: HttpResponse ) -> HttpResponseBase: @@ -332,6 +385,7 @@ def add_operation( view_func: Callable, *, auth: Optional[Union[Sequence[Callable], Callable, object]] = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -359,6 +413,7 @@ def add_operation( methods, view_func, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, diff --git a/ninja/router.py b/ninja/router.py index 7b5acfecd..72ade271d 100644 --- a/ninja/router.py +++ b/ninja/router.py @@ -1,11 +1,22 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, +) from django.urls import URLPattern from django.urls import path as django_path -from ninja.constants import NOT_SET +from ninja.constants import NOT_SET, NOT_SET_TYPE from ninja.errors import ConfigError from ninja.operation import PathView +from ninja.throttling import BaseThrottle from ninja.types import TCallable from ninja.utils import normalize_path, replace_path_param_notation @@ -18,10 +29,15 @@ class Router: def __init__( - self, *, auth: Any = NOT_SET, tags: Optional[List[str]] = None + self, + *, + auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, + tags: Optional[List[str]] = None, ) -> None: self.api: Optional["NinjaAPI"] = None self.auth = auth + self.throttle: List[BaseThrottle] = throttle self.tags = tags self.path_operations: Dict[str, PathView] = {} self._routers: List[Tuple[str, Router]] = [] @@ -31,6 +47,7 @@ def get( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -49,6 +66,7 @@ def get( ["GET"], path, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -69,6 +87,7 @@ def post( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -87,6 +106,7 @@ def post( ["POST"], path, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -107,6 +127,7 @@ def delete( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -125,6 +146,7 @@ def delete( ["DELETE"], path, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -145,6 +167,7 @@ def patch( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -163,6 +186,7 @@ def patch( ["PATCH"], path, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -183,6 +207,7 @@ def put( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -201,6 +226,7 @@ def put( ["PUT"], path, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -222,6 +248,7 @@ def api_operation( path: str, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -242,6 +269,7 @@ def decorator(view_func: TCallable) -> TCallable: methods, view_func, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -267,6 +295,7 @@ def add_api_operation( view_func: Callable, *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, summary: Optional[str] = None, @@ -291,6 +320,7 @@ def add_api_operation( methods=methods, view_func=view_func, auth=auth, + throttle=throttle, response=response, operation_id=operation_id, summary=summary, @@ -344,17 +374,24 @@ def add_router( router: "Router", *, auth: Any = NOT_SET, + throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, tags: Optional[List[str]] = None, ) -> None: if self.api: # we are already attached to an api self.api.add_router( - prefix=prefix, router=router, auth=auth, tags=tags, parent_router=self + prefix=prefix, + router=router, + auth=auth, + throttle=throttle, + tags=tags, + parent_router=self, ) else: # we are not attached to an api if auth != NOT_SET: router.auth = auth + # TODO: throttle if tags is not None: router.tags = tags self._routers.append((prefix, router)) diff --git a/ninja/testing/client.py b/ninja/testing/client.py index c0b81f1b4..0be8ce0a8 100644 --- a/ninja/testing/client.py +++ b/ninja/testing/client.py @@ -120,10 +120,11 @@ def _build_request( request.is_secure.return_value = False request.build_absolute_uri = build_absolute_uri + request.auth = None if "user" not in request_params: request.user.is_authenticated = False - request.META = request_params.pop("META", {}) + request.META = request_params.pop("META", {"REMOTE_ADDR": "127.0.0.1"}) request.FILES = request_params.pop("FILES", {}) request.META.update( diff --git a/ninja/throttling.py b/ninja/throttling.py new file mode 100644 index 000000000..5a59ff1f2 --- /dev/null +++ b/ninja/throttling.py @@ -0,0 +1,226 @@ +""" +Provides various throttling policies. +""" + +import time + +from django.core.cache import cache as default_cache +from django.core.exceptions import ImproperlyConfigured + + +class BaseThrottle: + """ + Rate throttling of requests. + """ + + def allow_request(self, request): + """ + Return `True` if the request should be allowed, `False` otherwise. + """ + raise NotImplementedError(".allow_request() must be overridden") + + def get_ident(self, request): + """ + Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR + if present and number of proxies is > 0. If not use all of + HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. + """ + from ninja.conf import settings + + xff = request.META.get("HTTP_X_FORWARDED_FOR") + remote_addr = request.META.get("REMOTE_ADDR") + num_proxies = settings.NUM_PROXIES + + if num_proxies is not None: + if num_proxies == 0 or xff is None: + return remote_addr + addrs = xff.split(",") + client_addr = addrs[-min(num_proxies, len(addrs))] + return client_addr.strip() + + return "".join(xff.split()) if xff else remote_addr + + def wait(self): + """ + Optionally, return a recommended number of seconds to wait before + the next request. + """ + return None + + +class SimpleRateThrottle(BaseThrottle): + """ + A simple cache implementation, that only requires `.get_cache_key()` + to be overridden. + + The rate (requests / seconds) is set by a `rate` attribute on the Throttle + class. The attribute is a string of the form 'number_of_requests/period'. + + Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') + + Previous request information used for throttling is stored in the cache. + """ + + from ninja.conf import settings + + cache = default_cache + timer = time.time + cache_format = "throttle_%(scope)s_%(ident)s" + scope = None + THROTTLE_RATES = settings.DEFAULT_THROTTLE_RATES + + def __init__(self, rate=None): + if rate: + self.rate = rate + else: + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) + + def get_cache_key(self, request): + """ + Should return a unique cache-key which can be used for throttling. + Must be overridden. + + May return `None` if the request should not be throttled. + """ + raise NotImplementedError(".get_cache_key() must be overridden") + + def get_rate(self): + """ + Determine the string representation of the allowed request rate. + """ + if not getattr(self, "scope", None): + msg = f"You must set either `.scope` or `.rate` for '{self.__class__.__name__}' throttle" + raise ImproperlyConfigured(msg) + + try: + return self.THROTTLE_RATES[self.scope] + except KeyError: + msg = f"No default throttle rate set for '{self.scope}' scope" + raise ImproperlyConfigured(msg) from None + + def parse_rate(self, rate): + """ + Given the request rate string, return a two tuple of: + , + """ + if rate is None: + return (None, None) + num, period = rate.split("/") + num_requests = int(num) + duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]] + return (num_requests, duration) + + def allow_request(self, request): + """ + Implement the check to see if the request should be throttled. + + On success calls `throttle_success`. + On failure calls `throttle_failure`. + """ + # if self.rate is None: + # return True + + self.key = self.get_cache_key(request) + if self.key is None: + return True + + self.history = self.cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() + + def throttle_success(self): + """ + Inserts the current request's timestamp along with the key + into the cache. + """ + self.history.insert(0, self.now) + self.cache.set(self.key, self.history, self.duration) + return True + + def throttle_failure(self): + """ + Called when a request to the API has failed due to throttling. + """ + return False + + def wait(self): + """ + Returns the recommended next request time in seconds. + """ + if self.history: + remaining_duration = self.duration - (self.now - self.history[-1]) + else: + remaining_duration = self.duration + + available_requests = self.num_requests - len(self.history) + 1 + if available_requests <= 0: + return None + + return remaining_duration / float(available_requests) + + +class AnonRateThrottle(SimpleRateThrottle): + """ + Limits the rate of API calls that may be made by a anonymous users. + + The IP address of the request will be used as the unique cache key. + """ + + scope = "anon" + + def get_cache_key(self, request): + if request.auth is not None: + return None # Only throttle unauthenticated requests. + + return self.cache_format % { + "scope": self.scope, + "ident": self.get_ident(request), + } + + +class AuthRateThrottle(SimpleRateThrottle): + """ + Limits the rate of API calls that may be made by a given user. + + The string representation of request.auth object will be used as a unique cache key. + If you use custom auth objects make sure to implement __str__ method. + For anonymous requests, the IP address of the request will be used. + """ + + scope = "auth" + + def get_cache_key(self, request): + if request.auth is not None: + ident = str(request.auth) # TODO: maybe auth should have an attribute that developer can overwrite + else: + ident = self.get_ident(request) + + return self.cache_format % {"scope": self.scope, "ident": ident} + + +class UserRateThrottle(SimpleRateThrottle): + """ + Limits the rate of API calls that may be made by a given user. + + The user id will be used as a unique cache key if the user is + authenticated. For anonymous requests, the IP address of the request will + be used. + """ + + scope = 'user' + + def get_cache_key(self, request): + if request.user and request.user.is_authenticated: + ident = request.user.pk + else: + ident = self.get_ident(request) + + return self.cache_format % {'scope': self.scope, 'ident': ident} diff --git a/tests/test_throttling.py b/tests/test_throttling.py new file mode 100644 index 000000000..f38570ea2 --- /dev/null +++ b/tests/test_throttling.py @@ -0,0 +1,279 @@ +import pytest +from django.core.cache import cache +from django.core.exceptions import ImproperlyConfigured + +from ninja import NinjaAPI, Router +from ninja.testing import TestClient +from ninja.throttling import ( + AnonRateThrottle, + AuthRateThrottle, + BaseThrottle, + SimpleRateThrottle, + UserRateThrottle, +) + + +@pytest.fixture(autouse=True) +def clear_cache_for_every_case(): + cache.clear() + + +def test_global_throttling(): + th = AnonRateThrottle("1/s") + set_throttle_timer(th, 0) + + api = NinjaAPI(throttle=[th]) + + @api.get("/check") + def check(request): + return "OK" + + client = TestClient(api) + + resp = client.get("/check") + assert resp.status_code == 200 + assert resp.content == b'"OK"' + + resp = client.get("/check") + assert resp.status_code == 429 + assert resp.json() == {"detail": "Too many requests."} + + set_throttle_timer(th, 2) + resp = client.get("/check") + assert resp.status_code == 200 + assert resp.content == b'"OK"' + + +def test_router_throttling(): + th = AnonRateThrottle("1/s") + set_throttle_timer(th, 0) + + api = NinjaAPI() + router = Router() + + @router.get("/check") + def check(request): + return "OK" + + api.add_router("/router", router, throttle=th) + + client = TestClient(api) + + resp = client.get("/router/check") + assert resp.status_code == 200 + assert resp.content == b'"OK"' + + resp = client.get("/router/check") + assert resp.status_code == 429 + assert resp.json() == {"detail": "Too many requests."} + + +def test_router2_throttling(): + "Here we test that child router inherits the throttling from api instance" + th = AnonRateThrottle("1/s") + set_throttle_timer(th, 0) + + api = NinjaAPI(throttle=th) + router = Router() + + @router.get("/check") + def check(request): + return "OK" + + api.add_router("/router", router) + + client = TestClient(api) + + resp = client.get("/router/check") + assert resp.status_code == 200 + assert resp.content == b'"OK"' + + resp = client.get("/router/check") + assert resp.status_code == 429 + assert resp.json() == {"detail": "Too many requests."} + + +def test_operation_throttling(): + th = AnonRateThrottle("1/s") + set_throttle_timer(th, 0) + + api = NinjaAPI() + + @api.get("/check1", throttle=th) + def check(request): + return "OK" + + client = TestClient(api) + + resp = client.get("/check1") + assert resp.status_code == 200 + assert resp.content == b'"OK"' + + resp = client.get("/check1") + assert resp.status_code == 429 + assert resp.json() == {"detail": "Too many requests."} + + +# "Unit tests" for the throttling module + +_client = TestClient(NinjaAPI()) + + +def build_request(addr="8.8.8.8", x_forwarded_for=None): + "Creates a mock request with the given address and optional X-Forwarded-For header." + meta = {"REMOTE_ADDR": addr} + if x_forwarded_for: + meta["HTTP_X_FORWARDED_FOR"] = x_forwarded_for + return _client._build_request("GET", "/", {}, {"META": meta}) + + +def test_throttle_anon(): + th = AnonRateThrottle("1/s") + set_throttle_timer(th, 0) + + request = build_request() + request.auth = None + + assert th.allow_request(request) is True + assert th.wait() == 1.0 + assert th.get_cache_key(request) == "throttle_anon_8.8.8.8" + + # Next should not allow as it's within the same second + assert th.allow_request(request) is False + + # For auth request it should always allowed + request.auth = "some" + assert th.allow_request(request) is True + assert th.allow_request(request) is True + assert th.allow_request(request) is True + assert th.get_cache_key(request) is None + + +def test_throttle_auth(): + th = AuthRateThrottle("1/s") + set_throttle_timer(th, 0) + + request = build_request() + request.auth = None + + assert th.allow_request(request) is True + assert th.allow_request(request) is False + + request.auth = "some" + assert th.allow_request(request) is True + assert th.allow_request(request) is False + + set_throttle_timer(th, 2) + assert th.allow_request(request) is True + + assert th.get_cache_key(request) == "throttle_auth_some" + + +def test_throttle_user(): + th = UserRateThrottle("1/s") + set_throttle_timer(th, 0) + + request = build_request() + request.user.is_authenticated = True + request.user.pk = 123 + + assert th.allow_request(request) is True + assert th.allow_request(request) is False + + set_throttle_timer(th, 2) + assert th.allow_request(request) is True + + assert th.get_cache_key(request) == "throttle_user_123" + + # Not authenticated user: + request.user.is_authenticated = False + assert th.allow_request(request) is True + assert th.allow_request(request) is False + assert ( + th.get_cache_key(request) == "throttle_user_8.8.8.8" + ) # not authenticated throttled by IP + + +def test_wait(): + th = AuthRateThrottle("5/m") + set_throttle_timer(th, 0) + + request = build_request() + request.auth = None + + for _i in range(5): + assert th.allow_request(request) is True + + assert th.allow_request(request) is False + assert th.wait() == 60 + + set_throttle_timer(th, 30) + assert th.allow_request(request) is False + assert th.wait() == 30 + + # Simulating cache expiration/reset + th.history = [] + # cache.clear() + set_throttle_timer(th, 0) + assert th.wait() == 10 # 60s / 6 available + + # Simulating larger history + th.history = [0] * 10 + th.now = 0 + assert th.wait() is None # available becomes negative + + +def test_rate_parser(): + th = SimpleRateThrottle("1/s") + assert th.parse_rate(None) == (None, None) + assert th.parse_rate("1/s") == (1, 1) + assert th.parse_rate("5/m") == (5, 60) + assert th.parse_rate("10/h") == (10, 3600) + assert th.parse_rate("100/d") == (100, 86400) + + +def test_proxy_throttle(): + from ninja.conf import settings + + settings.NUM_PROXIES = 0 # instead of None + + th = SimpleRateThrottle("1/s") + request = build_request(x_forwarded_for=None) + assert th.get_ident(request) == "8.8.8.8" + + settings.NUM_PROXIES = 0 + request = build_request(x_forwarded_for="8.8.8.8,127.0.0.1") + assert th.get_ident(request) == "8.8.8.8" + + settings.NUM_PROXIES = 1 + assert th.get_ident(request) == "127.0.0.1" + + settings.NUM_PROXIES = None + + +def test_base_classes(): + base = BaseThrottle() + with pytest.raises(NotImplementedError): + base.allow_request(build_request()) + assert base.wait() is None + + sample = SimpleRateThrottle("1/s") + with pytest.raises(NotImplementedError): + sample.allow_request(build_request()) + + throttle = AnonRateThrottle() + with pytest.raises(ImproperlyConfigured): + throttle.scope = None + throttle.get_rate() + + sample_scope2 = SimpleRateThrottle("1/s") + sample_scope2.scope = "scope2" + with pytest.raises(ImproperlyConfigured): + sample_scope2.get_rate() + + +def set_throttle_timer(throttle: BaseThrottle, value: int): + """ + Explicitly set the timer, overriding time.time() + """ + throttle.timer = lambda: value From e34e0796b3f16e89c66bc8ac1776e581b192f517 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 16:01:12 +0300 Subject: [PATCH 2/7] WIP throttling --- ninja/conf.py | 8 ++++++-- ninja/throttling.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ninja/conf.py b/ninja/conf.py index d402cdcc1..992322aef 100644 --- a/ninja/conf.py +++ b/ninja/conf.py @@ -7,7 +7,9 @@ class Settings(BaseModel): # Pagination - PAGINATION_CLASS: str = Field("ninja.pagination.LimitOffsetPagination", alias="NINJA_PAGINATION_CLASS") + PAGINATION_CLASS: str = Field( + "ninja.pagination.LimitOffsetPagination", alias="NINJA_PAGINATION_CLASS" + ) PAGINATION_PER_PAGE: int = Field(100, alias="NINJA_PAGINATION_PER_PAGE") PAGINATION_MAX_LIMIT: int = Field(inf, alias="NINJA_PAGINATION_MAX_LIMIT") @@ -29,4 +31,6 @@ class Config: settings = Settings.model_validate(django_settings) if hasattr(django_settings, "NINJA_DOCS_VIEW"): - raise Exception("NINJA_DOCS_VIEW is removed. Use NinjaAPI(docs=...) instead") # pragma: no cover + raise Exception( + "NINJA_DOCS_VIEW is removed. Use NinjaAPI(docs=...) instead" + ) # pragma: no cover diff --git a/ninja/throttling.py b/ninja/throttling.py index 5a59ff1f2..b0e3ca2d5 100644 --- a/ninja/throttling.py +++ b/ninja/throttling.py @@ -199,7 +199,9 @@ class AuthRateThrottle(SimpleRateThrottle): def get_cache_key(self, request): if request.auth is not None: - ident = str(request.auth) # TODO: maybe auth should have an attribute that developer can overwrite + ident = str( + request.auth + ) # TODO: maybe auth should have an attribute that developer can overwrite else: ident = self.get_ident(request) @@ -215,7 +217,7 @@ class UserRateThrottle(SimpleRateThrottle): be used. """ - scope = 'user' + scope = "user" def get_cache_key(self, request): if request.user and request.user.is_authenticated: @@ -223,4 +225,4 @@ def get_cache_key(self, request): else: ident = self.get_ident(request) - return self.cache_format % {'scope': self.scope, 'ident': ident} + return self.cache_format % {"scope": self.scope, "ident": ident} From 4d35929f9e1770034fe0e1a505dd36ad7f954115 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 17:18:41 +0300 Subject: [PATCH 3/7] hashing throttle key for auth --- ninja/throttling.py | 6 +++--- tests/test_throttling.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ninja/throttling.py b/ninja/throttling.py index b0e3ca2d5..0083cec06 100644 --- a/ninja/throttling.py +++ b/ninja/throttling.py @@ -2,6 +2,7 @@ Provides various throttling policies. """ +import hashlib import time from django.core.cache import cache as default_cache @@ -199,9 +200,8 @@ class AuthRateThrottle(SimpleRateThrottle): def get_cache_key(self, request): if request.auth is not None: - ident = str( - request.auth - ) # TODO: maybe auth should have an attribute that developer can overwrite + ident = hashlib.sha256(str(request.auth).encode()).hexdigest() + # TODO: ^maybe auth should have an attribute that developer can overwrite else: ident = self.get_ident(request) diff --git a/tests/test_throttling.py b/tests/test_throttling.py index f38570ea2..e3c93b660 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -166,7 +166,10 @@ def test_throttle_auth(): set_throttle_timer(th, 2) assert th.allow_request(request) is True - assert th.get_cache_key(request) == "throttle_auth_some" + assert ( + th.get_cache_key(request) + == "throttle_auth_a6b46dd0d1ae5e86cbc8f37e75ceeb6760230c1ca4ffbcb0c97b96dd7d9c464b" + ) def test_throttle_user(): From 3db8d587b4079d77063907344a59400f617c5e11 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 17:18:47 +0300 Subject: [PATCH 4/7] docs --- docs/docs/guides/throttling.md | 92 ++++++++++++++++++++++++++++++++++ docs/mkdocs.yml | 1 + 2 files changed, 93 insertions(+) create mode 100644 docs/docs/guides/throttling.md diff --git a/docs/docs/guides/throttling.md b/docs/docs/guides/throttling.md new file mode 100644 index 000000000..941e6e0c6 --- /dev/null +++ b/docs/docs/guides/throttling.md @@ -0,0 +1,92 @@ +# Throttling + +Throttles allows to control the rate of requests that clients can make to an API. Django Ninja allows to set custom throttlers globally (across all operations in NinjaAPI instance), on router level and each operation individually. + +!!! note + The application-level throttling that Django Ninja provides should not be considered a security measure or protection against brute forcing or denial-of-service attacks. Deliberately malicious actors will always be able to spoof IP origins. The built-in throttling implementations are implemented using Django's cache framework, and use non-atomic operations to determine the request rate, which may sometimes result in some fuzziness. + + +Django Ninja’s throttling feature is pretty much based on what Django Rest Framework (DRF) uses, which you can check out [here](https://www.django-rest-framework.org/api-guide/throttling/). So, if you’ve already got custom throttling set up for DRF, there’s a good chance it’ll work with Django Ninja right out of the box. They key difference is that you need to pass initialized Throttle objects instead of classes (which should give a better performance) + + +## Usage + +### Global + +The following example will limit unauthenticated users to only 10 requests per second, while authenticated can make 100/s + +```Python +from ninja.throttling import AnonRateThrottle, AuthRateThrottle + +api = NinjaAPI + throttle=[ + AnonRateThrottle('10/s'), + AuthRateThrottle('100/s'), + ], +) +``` + +!!! tip + `throttle` argument accepts single object and list of throttle objects + +### Router level + +```Python +api = NinjaAPI() +... + +api.add_router('/sensitive', 'myapp.api.router', throttle=AnonRateThrottle('100/m')) +``` + +or + +```Python +router = Router(..., throttle=[AnonRateThrottle('1000/h')]) +``` + + +### Operation level + + +```Python +from ninja.throttling import UserRateThrottle + +@api.get('/some', throttle=[UserRateThrottle('10000/d')]) +def some(request): + ... +``` + +## Builtin throttlers + +### AnonRateThrottle + +Will only throttle unauthenticated users. The IP address of the incoming request is used to generate a unique key to throttle against. + + +### UserRateThrottle + +Will throttle users (**if you use django build-in user authentication**) to a given rate of requests across the API. The user id is used to generate a unique key to throttle against. Unauthenticated requests will fall back to using the IP address of the incoming request to generate a unique key to throttle against. + +### AuthRateThrottle + +Will throttle by Django ninja [authentication](guides/authentication.md) to a given rate of requests across the API. Unauthenticated requests will fall back to using the IP address of the incoming request to generate a unique key to throttle against. + +Note: the cache key in case of `request.auth` will be generated by `sha256(str(request.auth))` - so if you returning some custom objects inside authentication make sure to implement `__str__` method that will return a unique value for the user. + + +## Custom throttles +To create a custom throttle, override `BaseThrottle` (or any of builtin throttles) and implement `.allow_request(self, request)`. The method should return `True` if the request should be allowed, and `False` otherwise. + +Example + +```Python +from ninja.throttling import AnonRateThrottle + +class NoReadsThrottle(AnonRateThrottle): + """Do not throttle GET requests""" + + def allow_request(self, request): + if request.method == "GET": + return True + return super().allow_request(request) +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 51f7be46b..fc06094d3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -69,6 +69,7 @@ nav: - guides/response/response-renderers.md - Splitting your API with Routers: guides/routers.md - guides/authentication.md + - guides/throttling.md - guides/testing.md - guides/api-docs.md - guides/errors.md From 06056fd70ed191503b0c9ec875dfd73107ff12c4 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 17:58:01 +0300 Subject: [PATCH 5/7] mypy --- ninja/errors.py | 4 +-- ninja/operation.py | 20 +++++++-------- ninja/router.py | 2 +- ninja/throttling.py | 61 ++++++++++++++++++++++----------------------- 4 files changed, 42 insertions(+), 45 deletions(-) diff --git a/ninja/errors.py b/ninja/errors.py index f938e0635..95a130085 100644 --- a/ninja/errors.py +++ b/ninja/errors.py @@ -1,7 +1,7 @@ import logging import traceback from functools import partial -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from django.conf import settings from django.http import Http404, HttpRequest, HttpResponse @@ -54,7 +54,7 @@ def __str__(self) -> str: class Throttled(HttpError): - def __init__(self, wait: int) -> None: + def __init__(self, wait: Optional[int]) -> None: self.wait = wait super().__init__(status_code=429, message="Too many requests.") diff --git a/ninja/operation.py b/ninja/operation.py index cdb407cfb..6b017c1ed 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -73,7 +73,7 @@ def __init__( self.throttle_param = throttle self.throttle_objects: List[BaseThrottle] = [] if throttle is not NOT_SET: - for th in throttle: + for th in throttle: # type: ignore assert isinstance( th, BaseThrottle ), "Throttle should be an instance of BaseThrottle" @@ -140,14 +140,11 @@ def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: self.throttle_objects = ( isinstance(api.throttle, BaseThrottle) and [api.throttle] - or api.throttle + or api.throttle # type: ignore ) if router.throttle != NOT_SET: - self.throttle_objects = ( - isinstance(router.throttle, BaseThrottle) - and [router.throttle] - or router.throttle - ) + _t = router.throttle + self.throttle_objects = isinstance(_t, BaseThrottle) and [_t] or _t # type: ignore assert all( isinstance(th, BaseThrottle) for th in self.throttle_objects ), "Throttle should be an instance of BaseThrottle" @@ -173,13 +170,13 @@ def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # auth: if self.auth_callbacks: - error = self._run_authentication(request) + error = self._run_authentication(request) # type: ignore if error: return error # Throttling: if self.throttle_objects: - error = self._check_throttles(request) + error = self._check_throttles(request) # type: ignore if error: return error @@ -213,7 +210,8 @@ def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: ] duration = max(durations, default=None) - return self.api.on_exception(request, Throttled(wait=duration)) + return self.api.on_exception(request, Throttled(wait=duration)) # type: ignore + return None def _result_to_response( self, request: HttpRequest, result: Any, temporal_response: HttpResponse @@ -384,7 +382,7 @@ def add_operation( methods: List[str], view_func: Callable, *, - auth: Optional[Union[Sequence[Callable], Callable, object]] = NOT_SET, + auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET, throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET, response: Any = NOT_SET, operation_id: Optional[str] = None, diff --git a/ninja/router.py b/ninja/router.py index 72ade271d..c1d1372ad 100644 --- a/ninja/router.py +++ b/ninja/router.py @@ -37,7 +37,7 @@ def __init__( ) -> None: self.api: Optional["NinjaAPI"] = None self.auth = auth - self.throttle: List[BaseThrottle] = throttle + self.throttle = throttle self.tags = tags self.path_operations: Dict[str, PathView] = {} self._routers: List[Tuple[str, Router]] = [] diff --git a/ninja/throttling.py b/ninja/throttling.py index 0083cec06..728b6b988 100644 --- a/ninja/throttling.py +++ b/ninja/throttling.py @@ -1,12 +1,10 @@ -""" -Provides various throttling policies. -""" - import hashlib import time +from typing import Dict, List, Optional, Tuple from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest class BaseThrottle: @@ -14,13 +12,13 @@ class BaseThrottle: Rate throttling of requests. """ - def allow_request(self, request): + def allow_request(self, request: HttpRequest) -> bool: """ Return `True` if the request should be allowed, `False` otherwise. """ raise NotImplementedError(".allow_request() must be overridden") - def get_ident(self, request): + def get_ident(self, request: HttpRequest) -> Optional[str]: """ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR if present and number of proxies is > 0. If not use all of @@ -35,13 +33,13 @@ def get_ident(self, request): if num_proxies is not None: if num_proxies == 0 or xff is None: return remote_addr - addrs = xff.split(",") + addrs: List[str] = xff.split(",") client_addr = addrs[-min(num_proxies, len(addrs))] return client_addr.strip() return "".join(xff.split()) if xff else remote_addr - def wait(self): + def wait(self) -> Optional[float]: """ Optionally, return a recommended number of seconds to wait before the next request. @@ -67,17 +65,18 @@ class SimpleRateThrottle(BaseThrottle): cache = default_cache timer = time.time cache_format = "throttle_%(scope)s_%(ident)s" - scope = None - THROTTLE_RATES = settings.DEFAULT_THROTTLE_RATES + scope: Optional[str] = None + THROTTLE_RATES: Dict[str, Optional[str]] = settings.DEFAULT_THROTTLE_RATES - def __init__(self, rate=None): + def __init__(self, rate: Optional[str] = None): + self.rate: Optional[str] if rate: self.rate = rate else: self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) - def get_cache_key(self, request): + def get_cache_key(self, request: HttpRequest) -> Optional[str]: """ Should return a unique cache-key which can be used for throttling. Must be overridden. @@ -86,7 +85,7 @@ def get_cache_key(self, request): """ raise NotImplementedError(".get_cache_key() must be overridden") - def get_rate(self): + def get_rate(self) -> Optional[str]: """ Determine the string representation of the allowed request rate. """ @@ -95,12 +94,12 @@ def get_rate(self): raise ImproperlyConfigured(msg) try: - return self.THROTTLE_RATES[self.scope] + return self.THROTTLE_RATES[self.scope] # type: ignore except KeyError: msg = f"No default throttle rate set for '{self.scope}' scope" raise ImproperlyConfigured(msg) from None - def parse_rate(self, rate): + def parse_rate(self, rate: Optional[str]) -> Tuple[Optional[int], Optional[int]]: """ Given the request rate string, return a two tuple of: , @@ -112,7 +111,7 @@ def parse_rate(self, rate): duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]] return (num_requests, duration) - def allow_request(self, request): + def allow_request(self, request: HttpRequest) -> bool: """ Implement the check to see if the request should be throttled. @@ -127,17 +126,17 @@ def allow_request(self, request): return True self.history = self.cache.get(self.key, []) - self.now = self.timer() + self.now = self.timer() # type: ignore # Drop any requests from the history which have now passed the # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: + while self.history and self.history[-1] <= self.now - self.duration: # type: ignore self.history.pop() - if len(self.history) >= self.num_requests: + if len(self.history) >= self.num_requests: # type: ignore return self.throttle_failure() return self.throttle_success() - def throttle_success(self): + def throttle_success(self) -> bool: """ Inserts the current request's timestamp along with the key into the cache. @@ -146,13 +145,13 @@ def throttle_success(self): self.cache.set(self.key, self.history, self.duration) return True - def throttle_failure(self): + def throttle_failure(self) -> bool: """ Called when a request to the API has failed due to throttling. """ return False - def wait(self): + def wait(self) -> Optional[float]: """ Returns the recommended next request time in seconds. """ @@ -161,11 +160,11 @@ def wait(self): else: remaining_duration = self.duration - available_requests = self.num_requests - len(self.history) + 1 + available_requests = self.num_requests - len(self.history) + 1 # type: ignore if available_requests <= 0: return None - return remaining_duration / float(available_requests) + return remaining_duration / float(available_requests) # type: ignore class AnonRateThrottle(SimpleRateThrottle): @@ -177,8 +176,8 @@ class AnonRateThrottle(SimpleRateThrottle): scope = "anon" - def get_cache_key(self, request): - if request.auth is not None: + def get_cache_key(self, request: HttpRequest) -> Optional[str]: + if request.auth is not None: # type: ignore return None # Only throttle unauthenticated requests. return self.cache_format % { @@ -198,12 +197,12 @@ class AuthRateThrottle(SimpleRateThrottle): scope = "auth" - def get_cache_key(self, request): - if request.auth is not None: - ident = hashlib.sha256(str(request.auth).encode()).hexdigest() + def get_cache_key(self, request: HttpRequest) -> str: + if request.auth is not None: # type: ignore + ident = hashlib.sha256(str(request.auth).encode()).hexdigest() # type: ignore # TODO: ^maybe auth should have an attribute that developer can overwrite else: - ident = self.get_ident(request) + ident = self.get_ident(request) # type: ignore return self.cache_format % {"scope": self.scope, "ident": ident} @@ -219,7 +218,7 @@ class UserRateThrottle(SimpleRateThrottle): scope = "user" - def get_cache_key(self, request): + def get_cache_key(self, request: HttpRequest) -> str: if request.user and request.user.is_authenticated: ident = request.user.pk else: From 023e24440632b447374348d714bf02663564949b Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 18:03:39 +0300 Subject: [PATCH 6/7] Throttle docs fixes --- docs/docs/guides/throttling.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/docs/guides/throttling.md b/docs/docs/guides/throttling.md index 941e6e0c6..56614bdde 100644 --- a/docs/docs/guides/throttling.md +++ b/docs/docs/guides/throttling.md @@ -18,7 +18,7 @@ The following example will limit unauthenticated users to only 10 requests per s ```Python from ninja.throttling import AnonRateThrottle, AuthRateThrottle -api = NinjaAPI +api = NinjaAPI( throttle=[ AnonRateThrottle('10/s'), AuthRateThrottle('100/s'), @@ -31,6 +31,8 @@ api = NinjaAPI ### Router level +Pass `throttle` argument either to `add_router` function + ```Python api = NinjaAPI() ... @@ -38,7 +40,7 @@ api = NinjaAPI() api.add_router('/sensitive', 'myapp.api.router', throttle=AnonRateThrottle('100/m')) ``` -or +or directly to init of the Router class: ```Python router = Router(..., throttle=[AnonRateThrottle('1000/h')]) @@ -47,6 +49,7 @@ router = Router(..., throttle=[AnonRateThrottle('1000/h')]) ### Operation level +If `throttle` argument is passed to operation - it will overrule all global and router throttles: ```Python from ninja.throttling import UserRateThrottle From 62888c743bd2f0dc9aaf165b403b1534a63f419f Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Wed, 26 Jun 2024 18:37:16 +0300 Subject: [PATCH 7/7] Throttling - fixed throttle when auth not used --- ninja/throttling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ninja/throttling.py b/ninja/throttling.py index 728b6b988..640f1041b 100644 --- a/ninja/throttling.py +++ b/ninja/throttling.py @@ -177,7 +177,7 @@ class AnonRateThrottle(SimpleRateThrottle): scope = "anon" def get_cache_key(self, request: HttpRequest) -> Optional[str]: - if request.auth is not None: # type: ignore + if getattr(request, "auth", None) is not None: return None # Only throttle unauthenticated requests. return self.cache_format % { @@ -198,7 +198,7 @@ class AuthRateThrottle(SimpleRateThrottle): scope = "auth" def get_cache_key(self, request: HttpRequest) -> str: - if request.auth is not None: # type: ignore + if getattr(request, "auth", None) is not None: ident = hashlib.sha256(str(request.auth).encode()).hexdigest() # type: ignore # TODO: ^maybe auth should have an attribute that developer can overwrite else: