Skip to content

Commit

Permalink
Fix minor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Twixes committed Aug 19, 2024
1 parent 538b81c commit 471c1a3
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 126 deletions.
3 changes: 2 additions & 1 deletion ee/api/test/test_billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,9 +777,10 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma
# Create a demo project
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
self.assertEqual(Team.objects.count(), 1)
response = self.client.post("/api/projects/", {"name": "Test", "is_demo": True})
self.assertEqual(response.status_code, 201)
self.assertEqual(Team.objects.count(), 3)
self.assertEqual(Team.objects.count(), 2)

demo_team = Team.objects.filter(is_demo=True).first()

Expand Down
8 changes: 4 additions & 4 deletions ee/api/test/test_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_create_demo_project(self, *args):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
response = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True})
self.assertEqual(Team.objects.count(), 3)
self.assertEqual(Team.objects.count(), 2)
self.assertEqual(response.status_code, 201)
response_data = response.json()
self.assertDictContainsSubset(
Expand All @@ -68,7 +68,7 @@ def test_create_two_demo_projects(self, *args):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
response = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True})
self.assertEqual(Team.objects.count(), 3)
self.assertEqual(Team.objects.count(), 2)
self.assertEqual(response.status_code, 201)
response_data = response.json()
self.assertDictContainsSubset(
Expand All @@ -80,13 +80,13 @@ def test_create_two_demo_projects(self, *args):
response_data,
)
response_2 = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True})
self.assertEqual(Team.objects.count(), 3)
self.assertEqual(Team.objects.count(), 2)
response_2_data = response_2.json()
self.assertDictContainsSubset(
{
"type": "authentication_error",
"code": "permission_denied",
"detail": "You must upgrade your PostHog plan to be able to create and manage multiple projects.",
"detail": "You must upgrade your PostHog plan to be able to create and manage multiple projects or environments.",
},
response_2_data,
)
Expand Down
2 changes: 1 addition & 1 deletion posthog/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def api_not_found(request):


def register_grandfathered_environment_nested_viewset(
prefix: str, viewset: viewsets.GenericViewSet, basename: str, parents_query_lookups: list[str]
prefix: str, viewset: type[viewsets.GenericViewSet], basename: str, parents_query_lookups: list[str]
) -> tuple[NestedRegistryItem, NestedRegistryItem]:
"""
Register the environment-specific viewset under both /environments/:team_id/ (correct endpoint)
Expand Down
165 changes: 47 additions & 118 deletions posthog/api/project.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from datetime import timedelta
from functools import cached_property
from typing import Any, Optional, cast
Expand All @@ -8,18 +7,15 @@
from loginas.utils import is_impersonated_session
from rest_framework import exceptions, request, response, serializers, viewsets
from rest_framework.decorators import action
from rest_framework.fields import SkipField
from rest_framework.permissions import IsAuthenticated
from rest_framework.relations import PKOnlyObject
from rest_framework.utils import model_meta

from posthog.api.geoip import get_geoip_properties
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import TeamBasicSerializer
from posthog.api.shared import ProjectBasicSerializer
from posthog.api.team import PremiumMultiProjectPermissions, TeamSerializer, validate_team_attrs
from posthog.event_usage import report_user_action
from posthog.jwt import PosthogJwtAudience, encode_jwt
from posthog.models import Team, User
from posthog.models import User
from posthog.models.activity_logging.activity_log import (
Change,
Detail,
Expand Down Expand Up @@ -48,7 +44,7 @@
from posthog.utils import get_ip_address, get_week_start_for_country_code


class ProjectSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin):
class ProjectSerializer(ProjectBasicSerializer, UserPermissionsSerializerMixin):
effective_membership_level = serializers.SerializerMethodField() # Compat with TeamSerializer
has_group_types = serializers.SerializerMethodField() # Compat with TeamSerializer
live_events_token = serializers.SerializerMethodField() # Compat with TeamSerializer
Expand Down Expand Up @@ -122,118 +118,48 @@ class Meta:

team_passthrough_fields = {
"updated_at",
"uuid", # Compat with TeamSerializer
"api_token", # Compat with TeamSerializer
"app_urls", # Compat with TeamSerializer
"slack_incoming_webhook", # Compat with TeamSerializer
"anonymize_ips", # Compat with TeamSerializer
"completed_snippet_onboarding", # Compat with TeamSerializer
"ingested_event", # Compat with TeamSerializer
"test_account_filters", # Compat with TeamSerializer
"test_account_filters_default_checked", # Compat with TeamSerializer
"path_cleaning_filters", # Compat with TeamSerializer
"is_demo", # Compat with TeamSerializer
"timezone", # Compat with TeamSerializer
"data_attributes", # Compat with TeamSerializer
"person_display_name_properties", # Compat with TeamSerializer
"correlation_config", # Compat with TeamSerializer
"autocapture_opt_out", # Compat with TeamSerializer
"autocapture_exceptions_opt_in", # Compat with TeamSerializer
"autocapture_web_vitals_opt_in", # Compat with TeamSerializer
"autocapture_exceptions_errors_to_ignore", # Compat with TeamSerializer
"capture_console_log_opt_in", # Compat with TeamSerializer
"capture_performance_opt_in", # Compat with TeamSerializer
"session_recording_opt_in", # Compat with TeamSerializer
"session_recording_sample_rate", # Compat with TeamSerializer
"session_recording_minimum_duration_milliseconds", # Compat with TeamSerializer
"session_recording_linked_flag", # Compat with TeamSerializer
"session_recording_network_payload_capture_config", # Compat with TeamSerializer
"session_replay_config", # Compat with TeamSerializer
"access_control", # Compat with TeamSerializer
"week_start_day", # Compat with TeamSerializer
"primary_dashboard", # Compat with TeamSerializer
"live_events_columns", # Compat with TeamSerializer
"recording_domains", # Compat with TeamSerializer
"person_on_events_querying_enabled", # Compat with TeamSerializer
"inject_web_apps", # Compat with TeamSerializer
"extra_settings", # Compat with TeamSerializer
"modifiers", # Compat with TeamSerializer
"default_modifiers", # Compat with TeamSerializer
"has_completed_onboarding_for", # Compat with TeamSerializer
"surveys_opt_in", # Compat with TeamSerializer
"heatmaps_opt_in", # Compat with TeamSerializer
"uuid",
"api_token",
"app_urls",
"slack_incoming_webhook",
"anonymize_ips",
"completed_snippet_onboarding",
"ingested_event",
"test_account_filters",
"test_account_filters_default_checked",
"path_cleaning_filters",
"is_demo",
"timezone",
"data_attributes",
"person_display_name_properties",
"correlation_config",
"autocapture_opt_out",
"autocapture_exceptions_opt_in",
"autocapture_web_vitals_opt_in",
"autocapture_exceptions_errors_to_ignore",
"capture_console_log_opt_in",
"capture_performance_opt_in",
"session_recording_opt_in",
"session_recording_sample_rate",
"session_recording_minimum_duration_milliseconds",
"session_recording_linked_flag",
"session_recording_network_payload_capture_config",
"session_replay_config",
"access_control",
"week_start_day",
"primary_dashboard",
"live_events_columns",
"recording_domains",
"person_on_events_querying_enabled",
"inject_web_apps",
"extra_settings",
"modifiers",
"default_modifiers",
"has_completed_onboarding_for",
"surveys_opt_in",
"heatmaps_opt_in",
}

def get_fields(self):
declared_fields = copy.deepcopy(self._declared_fields)

info = model_meta.get_field_info(Project)
team_info = model_meta.get_field_info(Team)
for field_name, field in team_info.fields.items():
if field_name in info.fields:
continue
info.fields[field_name] = field
info.fields_and_pk[field_name] = field
for field_name, relation in team_info.forward_relations.items():
if field_name in info.forward_relations:
continue
info.forward_relations[field_name] = relation
info.relations[field_name] = relation
for accessor_name, relation in team_info.reverse_relations.items():
if accessor_name in info.reverse_relations:
continue
info.reverse_relations[accessor_name] = relation
info.relations[accessor_name] = relation

field_names = self.get_field_names(declared_fields, info)

extra_kwargs = self.get_extra_kwargs()
extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs(field_names, declared_fields, extra_kwargs)

fields = {}
for field_name in field_names:
if field_name in declared_fields:
fields[field_name] = declared_fields[field_name]
continue
extra_field_kwargs = extra_kwargs.get(field_name, {})
source = extra_field_kwargs.get("source", "*")
if source == "*":
source = field_name
field_class, field_kwargs = self.build_field(source, info, model_class=Project, nested_depth=0)
field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs)
fields[field_name] = field_class(**field_kwargs)
fields.update(hidden_fields)
return fields

def build_field(self, field_name, info, model_class, nested_depth):
if field_name in self.Meta.team_passthrough_fields:
model_class = Team
return super().build_field(field_name, info, model_class, nested_depth)

def to_representation(self, instance):
"""
Object instance -> Dict of primitive datatypes.
"""
ret = {}
fields = self._readable_fields

for field in fields:
try:
attribute_source = instance
if field.field_name in self.Meta.team_passthrough_fields:
attribute_source = instance.passthrough_team
attribute = field.get_attribute(attribute_source)
except SkipField:
continue

check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
if check_for_none is None:
ret[field.field_name] = None
else:
ret[field.field_name] = field.to_representation(attribute)

return ret

def get_effective_membership_level(self, project: Project) -> Optional[OrganizationMembership.Level]:
team = project.teams.get(pk=project.pk)
return self.user_permissions.team(team).effective_membership_level
Expand Down Expand Up @@ -342,6 +268,9 @@ def update(self, instance: Project, validated_data: dict[str, Any]) -> Project:
should_team_be_saved_too = True
setattr(team, attr, value)
else:
if attr == "name": # `name` should be updated on _both_ the Project and Team
should_team_be_saved_too = True
setattr(team, attr, value)
setattr(instance, attr, value)

instance.save()
Expand Down Expand Up @@ -386,7 +315,7 @@ def safely_get_queryset(self, queryset):

def get_serializer_class(self) -> type[serializers.BaseSerializer]:
if self.action == "list":
return TeamBasicSerializer
return ProjectBasicSerializer
return super().get_serializer_class()

# NOTE: Team permissions are somewhat complex so we override the underlying viewset's get_permissions method
Expand Down
109 changes: 109 additions & 0 deletions posthog/api/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
This module contains serializers that are used across other serializers for nested representations.
"""

import copy
from typing import Optional

from rest_framework import serializers

from posthog.models import Organization, Team, User
from posthog.models.organization import OrganizationMembership
from posthog.models.project import Project
from rest_framework.fields import SkipField
from rest_framework.relations import PKOnlyObject
from rest_framework.utils import model_meta


class UserBasicSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -36,6 +41,110 @@ def get_hedgehog_config(self, user: User) -> Optional[dict]:
return None


class ProjectBasicSerializer(serializers.ModelSerializer):
"""
Serializer for `Project` model with minimal attributes to speeed up loading and transfer times.
Also used for nested serializers.
"""

class Meta:
model = Project
fields = (
"id",
"uuid", # Compat with TeamSerializer
"organization",
"api_token", # Compat with TeamSerializer
"name",
"completed_snippet_onboarding", # Compat with TeamSerializer
"has_completed_onboarding_for", # Compat with TeamSerializer
"ingested_event", # Compat with TeamSerializer
"is_demo", # Compat with TeamSerializer
"timezone", # Compat with TeamSerializer
"access_control", # Compat with TeamSerializer
)
read_only_fields = fields
team_passthrough_fields = {
"uuid",
"api_token",
"completed_snippet_onboarding",
"has_completed_onboarding_for",
"ingested_event",
"is_demo",
"timezone",
"access_control",
}

def get_fields(self):
declared_fields = copy.deepcopy(self._declared_fields)

info = model_meta.get_field_info(Project)
team_info = model_meta.get_field_info(Team)
for field_name, field in team_info.fields.items():
if field_name in info.fields:
continue
info.fields[field_name] = field
info.fields_and_pk[field_name] = field
for field_name, relation in team_info.forward_relations.items():
if field_name in info.forward_relations:
continue
info.forward_relations[field_name] = relation
info.relations[field_name] = relation
for accessor_name, relation in team_info.reverse_relations.items():
if accessor_name in info.reverse_relations:
continue
info.reverse_relations[accessor_name] = relation
info.relations[accessor_name] = relation

field_names = self.get_field_names(declared_fields, info)

extra_kwargs = self.get_extra_kwargs()
extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs(field_names, declared_fields, extra_kwargs)

fields = {}
for field_name in field_names:
if field_name in declared_fields:
fields[field_name] = declared_fields[field_name]
continue
extra_field_kwargs = extra_kwargs.get(field_name, {})
source = extra_field_kwargs.get("source", "*")
if source == "*":
source = field_name
field_class, field_kwargs = self.build_field(source, info, model_class=Project, nested_depth=0)
field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs)
fields[field_name] = field_class(**field_kwargs)
fields.update(hidden_fields)
return fields

def build_field(self, field_name, info, model_class, nested_depth):
if field_name in self.Meta.team_passthrough_fields:
model_class = Team
return super().build_field(field_name, info, model_class, nested_depth)

def to_representation(self, instance):
"""
Object instance -> Dict of primitive datatypes.
"""
ret = {}
fields = self._readable_fields

for field in fields:
try:
attribute_source = instance
if field.field_name in self.Meta.team_passthrough_fields:
attribute_source = instance.passthrough_team
attribute = field.get_attribute(attribute_source)
except SkipField:
continue

check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
if check_for_none is None:
ret[field.field_name] = None
else:
ret[field.field_name] = field.to_representation(attribute)

return ret


class TeamBasicSerializer(serializers.ModelSerializer):
"""
Serializer for `Team` model with minimal attributes to speeed up loading and transfer times.
Expand Down
Loading

0 comments on commit 471c1a3

Please sign in to comment.