diff --git a/posthog/api/authentication.py b/posthog/api/authentication.py index d06e7168d0df2..9b7dc954a97fa 100644 --- a/posthog/api/authentication.py +++ b/posthog/api/authentication.py @@ -22,7 +22,6 @@ from rest_framework.exceptions import APIException from rest_framework.request import Request from rest_framework.response import Response -from rest_framework.throttling import UserRateThrottle from sentry_sdk import capture_exception from social_django.views import auth from two_factor.utils import default_device @@ -36,14 +35,11 @@ from posthog.email import is_email_available from posthog.event_usage import report_user_logged_in, report_user_password_reset from posthog.models import OrganizationDomain, User +from posthog.rate_limit import UserPasswordResetThrottle from posthog.tasks.email import send_password_reset from posthog.utils import get_instance_available_sso_providers -class UserPasswordResetThrottle(UserRateThrottle): - rate = "6/day" - - @csrf_protect def logout(request): if request.user.is_authenticated: @@ -190,6 +186,7 @@ class LoginViewSet(NonCreatingViewSetMixin, viewsets.GenericViewSet): queryset = User.objects.none() serializer_class = LoginSerializer permission_classes = (permissions.AllowAny,) + # NOTE: Throttling is handled by the `axes` package class TwoFactorSerializer(serializers.Serializer): diff --git a/posthog/api/test/test_authentication.py b/posthog/api/test/test_authentication.py index 3d054e4cb1ac9..a33a59dd0549b 100644 --- a/posthog/api/test/test_authentication.py +++ b/posthog/api/test/test_authentication.py @@ -434,6 +434,23 @@ def test_cant_reset_more_than_six_times(self): # Three emails should be sent, fourth should not self.assertEqual(len(mail.outbox), 6) + def test_is_rate_limited_on_email_not_ip(self): + set_instance_setting("EMAIL_HOST", "localhost") + + for email in ["email@posthog.com", "other-email@posthog.com"]: + for i in range(7): + with self.settings(CELERY_TASK_ALWAYS_EAGER=True, SITE_URL="https://my.posthog.net"): + response = self.client.post("/api/reset/", {"email": email}) + if i < 6: + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + else: + # Fourth request should fail + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + self.assertDictContainsSubset( + {"attr": None, "code": "throttled", "type": "throttled_error"}, + response.json(), + ) + # Token validation def test_can_validate_token(self): diff --git a/posthog/api/user.py b/posthog/api/user.py index 7e72d46b88cb8..28b4a42b8620a 100644 --- a/posthog/api/user.py +++ b/posthog/api/user.py @@ -23,7 +23,7 @@ from rest_framework.exceptions import NotFound from rest_framework.permissions import IsAuthenticated, AllowAny from rest_framework.response import Response -from rest_framework.throttling import UserRateThrottle + from two_factor.forms import TOTPDeviceForm from two_factor.utils import default_device @@ -46,6 +46,7 @@ from posthog.models.organization import Organization from posthog.models.user import NOTIFICATION_DEFAULTS, Notifications from posthog.permissions import APIScopePermission +from posthog.rate_limit import UserAuthenticationThrottle, UserEmailVerificationThrottle from posthog.tasks import user_identify from posthog.tasks.email import send_email_change_emails from posthog.user_permissions import UserPermissions @@ -53,20 +54,6 @@ from posthog.constants import PERMITTED_FORUM_DOMAINS -class UserAuthenticationThrottle(UserRateThrottle): - rate = "5/minute" - - def allow_request(self, request, view): - # only throttle non-GET requests - if request.method == "GET": - return True - return super().allow_request(request, view) - - -class UserEmailVerificationThrottle(UserRateThrottle): - rate = "6/day" - - class ScenePersonalisationBasicSerializer(serializers.ModelSerializer): class Meta: model = UserScenePersonalisation diff --git a/posthog/rate_limit.py b/posthog/rate_limit.py index dbaa478d9f462..856d1b6cceb32 100644 --- a/posthog/rate_limit.py +++ b/posthog/rate_limit.py @@ -1,3 +1,4 @@ +import hashlib import re import time from functools import lru_cache @@ -222,6 +223,26 @@ def get_bucket_key(self, request): return ident +class UserOrEmailRateThrottle(SimpleRateThrottle): + """ + Typically throttling is on the user or the IP address. + For unauthenticated signup/login requests we want to throttle on the email address. + """ + + scope = "user" + + def get_cache_key(self, request, view): + if request.user and request.user.is_authenticated: + ident = request.user.pk + else: + # For unauthenticated requests, we want to throttle on something unique to the user they are trying to work with + # This could be email for example when logging in or uuid when verifying email + ident = request.data.get("email") or request.data.get("uuid") or self.get_ident(request) + ident = hashlib.sha256(ident.encode()).hexdigest() + + return self.cache_format % {"scope": self.scope, "ident": ident} + + class BurstRateThrottle(TeamRateThrottle): # Throttle class that's applied on all endpoints (except for capture + decide) # Intended to block quick bursts of requests, per project @@ -262,3 +283,24 @@ class AISustainedRateThrottle(UserRateThrottle): # Intended to block slower but sustained bursts of requests, per user scope = "ai_sustained" rate = "40/day" + + +class UserPasswordResetThrottle(UserOrEmailRateThrottle): + scope = "user_password_reset" + rate = "6/day" + + +class UserAuthenticationThrottle(UserOrEmailRateThrottle): + scope = "user_authentication" + rate = "5/minute" + + def allow_request(self, request, view): + # only throttle non-GET requests + if request.method == "GET": + return True + return super().allow_request(request, view) + + +class UserEmailVerificationThrottle(UserOrEmailRateThrottle): + scope = "user_email_verification" + rate = "6/day"