Skip to content

Commit

Permalink
feat: Use personal api keys for throttling instead of teams (#24039)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Matloka <[email protected]>
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 20, 2024
1 parent c333823 commit 3020152
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 75 deletions.
14 changes: 12 additions & 2 deletions frontend/src/scenes/settings/user/PersonalAPIKeys.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { LemonField } from 'lib/lemon-ui/LemonField'
import { capitalizeFirstLetter, humanFriendlyDetailedTime } from 'lib/utils'
import { Fragment, useEffect } from 'react'

import { API_KEY_SCOPE_PRESETS, APIScopes, personalAPIKeysLogic } from './personalAPIKeysLogic'
import { API_KEY_SCOPE_PRESETS, APIScopes, MAX_API_KEYS_PER_USER, personalAPIKeysLogic } from './personalAPIKeysLogic'

function EditKeyModal(): JSX.Element {
const {
Expand Down Expand Up @@ -468,6 +468,7 @@ function PersonalAPIKeysTable(): JSX.Element {
}

export function PersonalAPIKeys(): JSX.Element {
const { keys } = useValues(personalAPIKeysLogic)
const { setEditingKeyId } = useActions(personalAPIKeysLogic)

return (
Expand All @@ -484,7 +485,16 @@ export function PersonalAPIKeys(): JSX.Element {
More about API authentication in PostHog Docs.
</Link>
</p>
<LemonButton type="primary" icon={<IconPlus />} onClick={() => setEditingKeyId('new')}>
<LemonButton
type="primary"
icon={<IconPlus />}
onClick={() => setEditingKeyId('new')}
disabledReason={
keys.length >= MAX_API_KEYS_PER_USER
? `You can only have ${MAX_API_KEYS_PER_USER} personal API keys. Remove an existing key before creating a new one.`
: false
}
>
Create personal API key
</LemonButton>

Expand Down
2 changes: 2 additions & 0 deletions frontend/src/scenes/settings/user/personalAPIKeysLogic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import { OrganizationBasicType, PersonalAPIKeyType, TeamBasicType } from '~/type

import type { personalAPIKeysLogicType } from './personalAPIKeysLogicType'

export const MAX_API_KEYS_PER_USER = 10 // Same as in posthog/api/personal_api_key.py

export const API_KEY_SCOPE_PRESETS = [
{ value: 'local_evaluation', label: 'Local feature flag evaluation', scopes: ['feature_flag:read'] },
{
Expand Down
8 changes: 8 additions & 0 deletions posthog/api/personal_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from posthog.user_permissions import UserPermissions


MAX_API_KEYS_PER_USER = 10 # Same as in personalAPIKeysLogic.tsx


class PersonalAPIKeySerializer(serializers.ModelSerializer):
# Specifying method name because the serializer class already has a get_value method
value = serializers.SerializerMethodField(method_name="get_key_value", read_only=True)
Expand Down Expand Up @@ -93,6 +96,11 @@ def to_representation(self, instance):

def create(self, validated_data: dict, **kwargs) -> PersonalAPIKey:
user = self.context["request"].user
count = PersonalAPIKey.objects.filter(user=user).count()
if count >= MAX_API_KEYS_PER_USER:
raise serializers.ValidationError(
f"You can only have {MAX_API_KEYS_PER_USER} personal API keys. Remove an existing key before creating a new one."
)
value = generate_random_token_personal()
mask_value = mask_key_value(value)
secure_value = hash_key_value(value)
Expand Down
8 changes: 2 additions & 6 deletions posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@
from posthog.hogql.errors import ExposedHogQLError
from posthog.hogql_queries.query_runner import ExecutionMode, execution_mode_from_refresh
from posthog.models.user import User
from posthog.rate_limit import (
AIBurstRateThrottle,
AISustainedRateThrottle,
TeamRateThrottle,
)
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle, PersonalApiKeyRateThrottle
from posthog.schema import QueryRequest, QueryResponseAlternative, QueryStatusResponse


class QueryThrottle(TeamRateThrottle):
class QueryThrottle(PersonalApiKeyRateThrottle):
scope = "query"
rate = "120/hour"

Expand Down
1 change: 1 addition & 0 deletions posthog/api/test/test_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3773,6 +3773,7 @@ def test_rate_limits_for_local_evaluation_are_independent(self, rate_limit_enabl
"scope": "burst",
"rate": "5/minute",
"path": f"/api/projects/TEAM_ID/feature_flags",
"hashed_personal_api_key": hash_key_value(personal_api_key),
},
)

Expand Down
2 changes: 2 additions & 0 deletions posthog/api/test/test_person.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def test_rate_limits_for_persons_are_independent(self, rate_limit_enabled_mock,
"scope": "burst",
"rate": "5/minute",
"path": f"/api/projects/TEAM_ID/feature_flags",
"hashed_personal_api_key": hash_key_value(personal_api_key),
},
)

Expand Down Expand Up @@ -911,6 +912,7 @@ def test_rate_limits_for_persons_are_independent(self, rate_limit_enabled_mock,
"scope": "persons",
"rate": "6/minute",
"path": f"/api/projects/TEAM_ID/persons/",
"hashed_personal_api_key": hash_key_value(personal_api_key),
},
)

Expand Down
18 changes: 18 additions & 0 deletions posthog/api/test/test_personal_api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ def test_create_personal_api_key(self):
}
assert data["value"].startswith("phx_") # Personal API key prefix

def test_create_too_many_api_keys(self):
    """Creating a key once the user already holds 10 keys is rejected with a 400 validation error."""

    def create_key(label):
        # Same endpoint and payload shape for every request; only the label varies.
        return self.client.post(
            "/api/personal_api_keys",
            {"label": label, "scopes": ["insight:read"], "scoped_organizations": [], "scoped_teams": []},
        )

    # Fill the quota. Each setup request is checked so that a broken setup
    # fails here instead of surfacing as a misleading failure of the limit check.
    for i in range(0, 10):
        assert create_key(i).status_code == 201

    # One key over the limit. The label is spelled out (same value as before)
    # rather than relying on the loop variable leaking out of the `for`.
    response = create_key(9)
    assert response.status_code == 400
    assert response.json() == {
        "type": "validation_error",
        "code": "invalid_input",
        "detail": "You can only have 10 personal API keys. Remove an existing key before creating a new one.",
        "attr": None,
    }

def test_create_personal_api_key_label_required(self):
response = self.client.post("/api/personal_api_keys/", {"label": ""})
assert response.status_code == 400
Expand Down
110 changes: 60 additions & 50 deletions posthog/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from posthog.models.instance_setting import get_instance_setting
from posthog.settings.utils import get_list
from token_bucket import Limiter, MemoryStorage
from posthog.models.personal_api_key import hash_key_value


RATE_LIMIT_EXCEEDED_COUNTER = Counter(
Expand Down Expand Up @@ -68,7 +69,7 @@ def is_decide_rate_limit_enabled() -> bool:
path_by_org_pattern = re.compile(r"/api/organizations/(.+)/")


class TeamRateThrottle(SimpleRateThrottle):
class PersonalApiKeyRateThrottle(SimpleRateThrottle):
@staticmethod
def safely_get_team_id_from_view(view):
"""
Expand All @@ -87,63 +88,72 @@ def allow_request(self, request, view):
return True

# Only rate limit authenticated requests made with a personal API key
if request.user.is_authenticated and PersonalAPIKeyAuthentication.find_key_with_source(request) is None:
personal_api_key = PersonalAPIKeyAuthentication.find_key_with_source(request)
if request.user.is_authenticated and personal_api_key is None:
return True

# As we're figuring out what our throttle limits should be, we don't actually want to throttle anything.
# Instead of throttling, this logs that the request would have been throttled.
try:
request_would_be_allowed = super().allow_request(request, view)
if not request_would_be_allowed:
team_id = self.safely_get_team_id_from_view(view)
path = getattr(request, "path", None)
if path:
path = path_by_team_pattern.sub("/api/projects/TEAM_ID/", path)
path = path_by_org_pattern.sub("/api/organizations/ORG_ID/", path)

if self.team_is_allowed_to_bypass_throttle(team_id):
statsd.incr(
"team_allowed_to_bypass_rate_limit_exceeded",
tags={"team_id": team_id, "path": path},
)
RATE_LIMIT_BYPASSED_COUNTER.labels(team_id=team_id, path=path).inc()
return True
else:
scope = getattr(self, "scope", None)
rate = getattr(self, "rate", None)

statsd.incr(
"rate_limit_exceeded",
tags={
"team_id": team_id,
"scope": scope,
"rate": rate,
"path": path,
},
)
RATE_LIMIT_EXCEEDED_COUNTER.labels(team_id=team_id, scope=scope, path=path).inc()

return request_would_be_allowed
if request_would_be_allowed:
return True

team_id = self.safely_get_team_id_from_view(view)
path = getattr(request, "path", None)
if path:
path = path_by_team_pattern.sub("/api/projects/TEAM_ID/", path)
path = path_by_org_pattern.sub("/api/organizations/ORG_ID/", path)

if self.team_is_allowed_to_bypass_throttle(team_id):
statsd.incr(
"team_allowed_to_bypass_rate_limit_exceeded",
tags={"team_id": team_id, "path": path},
)
RATE_LIMIT_BYPASSED_COUNTER.labels(team_id=team_id, path=path).inc()
return True
else:
scope = getattr(self, "scope", None)
rate = getattr(self, "rate", None)

statsd.incr(
"rate_limit_exceeded",
tags={
"team_id": team_id,
"scope": scope,
"rate": rate,
"path": path,
"hashed_personal_api_key": hash_key_value(personal_api_key[0]) if personal_api_key else None,
},
)
RATE_LIMIT_EXCEEDED_COUNTER.labels(team_id=team_id, scope=scope, path=path).inc()

return False
except Exception as e:
capture_exception(e)
return True

def get_cache_key(self, request, view):
"""
Attempts to throttle based on the team_id of the request. If it can't do that, it falls back to the user_id.
And then finally to the IP address.
Tries the following options in order:
- personal_api_key
- team_id
- user_id
- ip
"""
ident = None
if request.user.is_authenticated:
try:
team_id = self.safely_get_team_id_from_view(view)
if team_id:
ident = team_id
else:
ident = request.user.pk
except Exception as e:
capture_exception(e)
ident = self.get_ident(request)
api_key = PersonalAPIKeyAuthentication.find_key_with_source(request)
if api_key is not None:
ident = hash_key_value(api_key[0])
else:
try:
team_id = self.safely_get_team_id_from_view(view)
if team_id:
ident = team_id
else:
ident = request.user.pk
except Exception as e:
capture_exception(e)
ident = self.get_ident(request)
else:
ident = self.get_ident(request)

Expand All @@ -157,7 +167,7 @@ def team_is_allowed_to_bypass_throttle(self, team_id: Optional[int]) -> bool:
class DecideRateThrottle(BaseThrottle):
"""
This is a custom throttle that is used to limit the number of requests to the /decide endpoint.
It is different from the TeamRateThrottle in that it does not use the Django cache, but instead
It is different from the PersonalApiKeyRateThrottle in that it does not use the Django cache, but instead
uses the Limiter from the `token-bucket` library.
This uses the token bucket algorithm to limit the number of requests to the endpoint. It's a lot
more performant than DRF's SimpleRateThrottle, which inefficiently uses the Django cache.
Expand Down Expand Up @@ -243,28 +253,28 @@ def get_cache_key(self, request, view):
return self.cache_format % {"scope": self.scope, "ident": ident}


class BurstRateThrottle(TeamRateThrottle):
class BurstRateThrottle(PersonalApiKeyRateThrottle):
# Throttle class that's applied on all endpoints (except for capture + decide)
# Intended to block quick bursts of requests, per personal API key (falling back to project, user, or IP)
scope = "burst"
rate = "480/minute"


class SustainedRateThrottle(TeamRateThrottle):
class SustainedRateThrottle(PersonalApiKeyRateThrottle):
# Throttle class that's applied on all endpoints (except for capture + decide)
# Intended to block slower but sustained bursts of requests, per personal API key (falling back to project, user, or IP)
scope = "sustained"
rate = "4800/hour"


class ClickHouseBurstRateThrottle(TeamRateThrottle):
class ClickHouseBurstRateThrottle(PersonalApiKeyRateThrottle):
# Throttle class that's a bit more aggressive and is used specifically on endpoints that hit ClickHouse
# Intended to block quick bursts of requests, per personal API key (falling back to project, user, or IP)
scope = "clickhouse_burst"
rate = "240/minute"


class ClickHouseSustainedRateThrottle(TeamRateThrottle):
class ClickHouseSustainedRateThrottle(PersonalApiKeyRateThrottle):
# Throttle class that's a bit more aggressive and is used specifically on endpoints that hit ClickHouse
# Intended to block slower but sustained bursts of requests, per personal API key (falling back to project, user, or IP)
scope = "clickhouse_sustained"
Expand Down
Loading

0 comments on commit 3020152

Please sign in to comment.