Skip to content

Commit

Permalink
feature gate API for org domains
Browse files Browse the repository at this point in the history
  • Loading branch information
raquelmsmith committed Mar 7, 2024
1 parent 53890a2 commit 8b644e6
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 30 deletions.
18 changes: 14 additions & 4 deletions posthog/api/organization_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.cloud_utils import is_cloud
from posthog.constants import AvailableFeature
from posthog.models import OrganizationDomain
from posthog.models.organization import Organization
from posthog.permissions import OrganizationAdminWritePermissions

DOMAIN_REGEX = r"^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$"
Expand Down Expand Up @@ -39,15 +41,16 @@ class Meta:
}

def create(self, validated_data: Dict[str, Any]) -> OrganizationDomain:
organization: Organization = self.context["view"].organization
if is_cloud() and not organization.is_feature_available(AvailableFeature.AUTOMATIC_PROVISIONING):
raise exceptions.PermissionDenied("Automatic provisioning is not enabled for this organization.")
validated_data.pop("jit_provisioning_enabled", None)
validated_data["organization"] = self.context["view"].organization
validated_data.pop(
"jit_provisioning_enabled", None
) # can never be set on creation because domain must be verified
validated_data.pop("sso_enforcement", None) # can never be set on creation because domain must be verified
instance = super().create(validated_data)

if not is_cloud():
instance, _ = instance.attempt_verification()
instance: OrganizationDomain = super().create(validated_data)

return instance

Expand All @@ -66,6 +69,13 @@ def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
{protected_attr: "This attribute cannot be updated until the domain is verified."},
code="verification_required",
)
if instance and attrs.get("jit_provisioning_enabled", None):
organization: Organization = self.context["view"].organization
if not organization.is_feature_available(AvailableFeature.AUTOMATIC_PROVISIONING):
raise serializers.ValidationError(
{"jit_provisioning_enabled": "Automatic provisioning is not enabled for this organization."},
code="feature_not_available",
)

return attrs

Expand Down
28 changes: 8 additions & 20 deletions posthog/api/test/test_organization_domain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
from unittest.mock import patch
from zoneinfo import ZoneInfo

import dns.resolver
import dns.rrset
import pytest
from zoneinfo import ZoneInfo
from django.utils import timezone
from freezegun import freeze_time
from rest_framework import status
Expand Down Expand Up @@ -87,6 +86,8 @@ def test_cannot_list_or_retrieve_domains_for_other_org(self):

def test_create_domain(self):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization.available_features = ["automatic_provisioning"]
self.organization.save()
self.organization_membership.save()

with self.is_cloud(True):
Expand All @@ -113,12 +114,11 @@ def test_create_domain(self):
self.assertEqual(instance.last_verification_retry, None)
self.assertEqual(instance.sso_enforcement, "")

@pytest.mark.skip_on_multitenancy
def test_creating_domain_on_self_hosted_is_automatically_verified(self):
def test_cant_create_domain_without_feature(self):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()

with freeze_time("2021-08-08T20:20:08Z"):
with self.is_cloud(True):
response = self.client.post(
"/api/organizations/@current/domains/",
{
Expand All @@ -129,21 +129,7 @@ def test_creating_domain_on_self_hosted_is_automatically_verified(self):
"sso_enforcement": "saml", # ignore me
},
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
response_data = response.json()
self.assertEqual(response_data["domain"], "the.posthog.com")
self.assertEqual(response_data["verified_at"], "2021-08-08T20:20:08Z")
self.assertEqual(response_data["jit_provisioning_enabled"], False)
self.assertRegex(response_data["verification_challenge"], r"[0-9A-Za-z_-]{32}")

instance = OrganizationDomain.objects.get(id=response_data["id"])
self.assertEqual(instance.domain, "the.posthog.com")
self.assertEqual(
instance.verified_at,
datetime.datetime(2021, 8, 8, 20, 20, 8, tzinfo=ZoneInfo("UTC")),
)
self.assertEqual(instance.last_verification_retry, None)
self.assertEqual(instance.sso_enforcement, "")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_cannot_create_duplicate_domain(self):
OrganizationDomain.objects.create(domain="i-registered-first.com", organization=self.another_org)
Expand Down Expand Up @@ -344,7 +330,9 @@ def test_only_admin_can_request_verification(self):

def test_can_update_jit_provisioning_and_sso_enforcement(self):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization.available_features = ["automatic_provisioning"]
self.organization_membership.save()
self.organization.save()
self.domain.verified_at = timezone.now()
self.domain.save()

Expand Down
1 change: 1 addition & 0 deletions posthog/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class AvailableFeature(str, Enum):
SURVEYS_STYLING = "surveys_styling"
SURVEYS_TEXT_HTML = "surveys_text_html"
SURVEYS_MULTIPLE_QUESTIONS = "surveys_multiple_questions"
AUTOMATIC_PROVISIONING = "automatic_provisioning"


TREND_FILTER_TYPE_ACTIONS = "actions"
Expand Down
6 changes: 0 additions & 6 deletions posthog/models/organization_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from django.db import models
from django.utils import timezone

from posthog.cloud_utils import is_cloud
from posthog.constants import AvailableFeature
from posthog.models import Organization
from posthog.models.utils import UUIDModel
Expand Down Expand Up @@ -161,11 +160,6 @@ def attempt_verification(self) -> Tuple["OrganizationDomain", bool]:
"""
Performs a DNS verification for a specific domain.
"""

if not is_cloud():
# We only do DNS validation on PostHog Cloud
return self._complete_verification()

try:
# TODO: Should we manually validate DNSSEC?
dns_response = dns.resolver.resolve(f"_posthog-challenge.{self.domain}", "TXT")
Expand Down

0 comments on commit 8b644e6

Please sign in to comment.