From 471c1a3f0bb429c40cccae4e6822d80f96d77cee Mon Sep 17 00:00:00 2001 From: Michael Matloka Date: Sat, 3 Aug 2024 00:36:54 +0200 Subject: [PATCH] Fix minor issues --- ee/api/test/test_billing.py | 3 +- ee/api/test/test_team.py | 8 +- posthog/api/__init__.py | 2 +- posthog/api/project.py | 165 ++++++++-------------------- posthog/api/shared.py | 109 ++++++++++++++++++ posthog/models/test/test_project.py | 2 +- posthog/test/test_middleware.py | 2 +- 7 files changed, 165 insertions(+), 126 deletions(-) diff --git a/ee/api/test/test_billing.py b/ee/api/test/test_billing.py index 6062516c4ccef..00beee3f0a671 100644 --- a/ee/api/test/test_billing.py +++ b/ee/api/test/test_billing.py @@ -777,9 +777,10 @@ def mock_implementation(url: str, headers: Any = None, params: Any = None) -> Ma # Create a demo project self.organization_membership.level = OrganizationMembership.Level.ADMIN self.organization_membership.save() + self.assertEqual(Team.objects.count(), 1) response = self.client.post("/api/projects/", {"name": "Test", "is_demo": True}) self.assertEqual(response.status_code, 201) - self.assertEqual(Team.objects.count(), 3) + self.assertEqual(Team.objects.count(), 2) demo_team = Team.objects.filter(is_demo=True).first() diff --git a/ee/api/test/test_team.py b/ee/api/test/test_team.py index db9fb7efdbf37..9ace30a718798 100644 --- a/ee/api/test/test_team.py +++ b/ee/api/test/test_team.py @@ -51,7 +51,7 @@ def test_create_demo_project(self, *args): self.organization_membership.level = OrganizationMembership.Level.ADMIN self.organization_membership.save() response = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True}) - self.assertEqual(Team.objects.count(), 3) + self.assertEqual(Team.objects.count(), 2) self.assertEqual(response.status_code, 201) response_data = response.json() self.assertDictContainsSubset( @@ -68,7 +68,7 @@ def test_create_two_demo_projects(self, *args): self.organization_membership.level = OrganizationMembership.Level.ADMIN self.organization_membership.save() response = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True}) - self.assertEqual(Team.objects.count(), 3) + self.assertEqual(Team.objects.count(), 2) self.assertEqual(response.status_code, 201) response_data = response.json() self.assertDictContainsSubset( @@ -80,13 +80,13 @@ def test_create_two_demo_projects(self, *args): response_data, ) response_2 = self.client.post("/api/projects/", {"name": "Hedgebox", "is_demo": True}) - self.assertEqual(Team.objects.count(), 3) + self.assertEqual(Team.objects.count(), 2) response_2_data = response_2.json() self.assertDictContainsSubset( { "type": "authentication_error", "code": "permission_denied", - "detail": "You must upgrade your PostHog plan to be able to create and manage multiple projects.", + "detail": "You must upgrade your PostHog plan to be able to create and manage multiple projects or environments.", }, response_2_data, ) diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index f85d63fbb7669..a35f6cdf083ef 100644 --- a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -87,7 +87,7 @@ def api_not_found(request): def register_grandfathered_environment_nested_viewset( - prefix: str, viewset: viewsets.GenericViewSet, basename: str, parents_query_lookups: list[str] + prefix: str, viewset: type[viewsets.GenericViewSet], basename: str, parents_query_lookups: list[str] ) -> tuple[NestedRegistryItem, NestedRegistryItem]: """ Register the environment-specific viewset under both /environments/:team_id/ (correct endpoint) diff --git a/posthog/api/project.py b/posthog/api/project.py index 41fbfa3649ac1..e08a35cac0838 100644 --- a/posthog/api/project.py +++ b/posthog/api/project.py @@ -1,4 +1,3 @@ -import copy from datetime import timedelta from functools import cached_property from typing import Any, Optional, cast @@ -8,18 +7,15 @@ from loginas.utils import is_impersonated_session from rest_framework import exceptions, request, response, serializers, viewsets from rest_framework.decorators import action -from rest_framework.fields import SkipField from rest_framework.permissions import IsAuthenticated -from rest_framework.relations import PKOnlyObject -from rest_framework.utils import model_meta from posthog.api.geoip import get_geoip_properties from posthog.api.routing import TeamAndOrgViewSetMixin -from posthog.api.shared import TeamBasicSerializer +from posthog.api.shared import ProjectBasicSerializer from posthog.api.team import PremiumMultiProjectPermissions, TeamSerializer, validate_team_attrs from posthog.event_usage import report_user_action from posthog.jwt import PosthogJwtAudience, encode_jwt -from posthog.models import Team, User +from posthog.models import User from posthog.models.activity_logging.activity_log import ( Change, Detail, @@ -48,7 +44,7 @@ from posthog.utils import get_ip_address, get_week_start_for_country_code -class ProjectSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin): +class ProjectSerializer(ProjectBasicSerializer, UserPermissionsSerializerMixin): effective_membership_level = serializers.SerializerMethodField() # Compat with TeamSerializer has_group_types = serializers.SerializerMethodField() # Compat with TeamSerializer live_events_token = serializers.SerializerMethodField() # Compat with TeamSerializer @@ -122,118 +118,48 @@ class Meta: team_passthrough_fields = { "updated_at", - "uuid", # Compat with TeamSerializer - "api_token", # Compat with TeamSerializer - "app_urls", # Compat with TeamSerializer - "slack_incoming_webhook", # Compat with TeamSerializer - "anonymize_ips", # Compat with TeamSerializer - "completed_snippet_onboarding", # Compat with TeamSerializer - "ingested_event", # Compat with TeamSerializer - "test_account_filters", # Compat with TeamSerializer - "test_account_filters_default_checked", # Compat with TeamSerializer - "path_cleaning_filters", # Compat with TeamSerializer - "is_demo", # Compat with TeamSerializer - "timezone", # Compat with TeamSerializer - "data_attributes", # Compat with TeamSerializer - "person_display_name_properties", # Compat with TeamSerializer - "correlation_config", # Compat with TeamSerializer - "autocapture_opt_out", # Compat with TeamSerializer - "autocapture_exceptions_opt_in", # Compat with TeamSerializer - "autocapture_web_vitals_opt_in", # Compat with TeamSerializer - "autocapture_exceptions_errors_to_ignore", # Compat with TeamSerializer - "capture_console_log_opt_in", # Compat with TeamSerializer - "capture_performance_opt_in", # Compat with TeamSerializer - "session_recording_opt_in", # Compat with TeamSerializer - "session_recording_sample_rate", # Compat with TeamSerializer - "session_recording_minimum_duration_milliseconds", # Compat with TeamSerializer - "session_recording_linked_flag", # Compat with TeamSerializer - "session_recording_network_payload_capture_config", # Compat with TeamSerializer - "session_replay_config", # Compat with TeamSerializer - "access_control", # Compat with TeamSerializer - "week_start_day", # Compat with TeamSerializer - "primary_dashboard", # Compat with TeamSerializer - "live_events_columns", # Compat with TeamSerializer - "recording_domains", # Compat with TeamSerializer - "person_on_events_querying_enabled", # Compat with TeamSerializer - "inject_web_apps", # Compat with TeamSerializer - "extra_settings", # Compat with TeamSerializer - "modifiers", # Compat with TeamSerializer - "default_modifiers", # Compat with TeamSerializer - "has_completed_onboarding_for", # Compat with TeamSerializer - "surveys_opt_in", # Compat with TeamSerializer - "heatmaps_opt_in", # Compat with TeamSerializer + "uuid", + "api_token", + "app_urls", + "slack_incoming_webhook", + "anonymize_ips", + "completed_snippet_onboarding", + "ingested_event", + "test_account_filters", + "test_account_filters_default_checked", + "path_cleaning_filters", + "is_demo", + "timezone", + "data_attributes", + "person_display_name_properties", + "correlation_config", + "autocapture_opt_out", + "autocapture_exceptions_opt_in", + "autocapture_web_vitals_opt_in", + "autocapture_exceptions_errors_to_ignore", + "capture_console_log_opt_in", + "capture_performance_opt_in", + "session_recording_opt_in", + "session_recording_sample_rate", + "session_recording_minimum_duration_milliseconds", + "session_recording_linked_flag", + "session_recording_network_payload_capture_config", + "session_replay_config", + "access_control", + "week_start_day", + "primary_dashboard", + "live_events_columns", + "recording_domains", + "person_on_events_querying_enabled", + "inject_web_apps", + "extra_settings", + "modifiers", + "default_modifiers", + "has_completed_onboarding_for", + "surveys_opt_in", + "heatmaps_opt_in", } - def get_fields(self): - declared_fields = copy.deepcopy(self._declared_fields) - - info = model_meta.get_field_info(Project) - team_info = model_meta.get_field_info(Team) - for field_name, field in team_info.fields.items(): - if field_name in info.fields: - continue - info.fields[field_name] = field - info.fields_and_pk[field_name] = field - for field_name, relation in team_info.forward_relations.items(): - if field_name in info.forward_relations: - continue - info.forward_relations[field_name] = relation - info.relations[field_name] = relation - for accessor_name, relation in team_info.reverse_relations.items(): - if accessor_name in info.reverse_relations: - continue - info.reverse_relations[accessor_name] = relation - info.relations[accessor_name] = relation - - field_names = self.get_field_names(declared_fields, info) - - extra_kwargs = self.get_extra_kwargs() - extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs(field_names, declared_fields, extra_kwargs) - - fields = {} - for field_name in field_names: - if field_name in declared_fields: - fields[field_name] = declared_fields[field_name] - continue - extra_field_kwargs = extra_kwargs.get(field_name, {}) - source = extra_field_kwargs.get("source", "*") - if source == "*": - source = field_name - field_class, field_kwargs = self.build_field(source, info, model_class=Project, nested_depth=0) - field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs) - fields[field_name] = field_class(**field_kwargs) - fields.update(hidden_fields) - return fields - - def build_field(self, field_name, info, model_class, nested_depth): - if field_name in self.Meta.team_passthrough_fields: - model_class = Team - return super().build_field(field_name, info, model_class, nested_depth) - - def to_representation(self, instance): - """ - Object instance -> Dict of primitive datatypes. - """ - ret = {} - fields = self._readable_fields - - for field in fields: - try: - attribute_source = instance - if field.field_name in self.Meta.team_passthrough_fields: - attribute_source = instance.passthrough_team - attribute = field.get_attribute(attribute_source) - except SkipField: - continue - - check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute - if check_for_none is None: - ret[field.field_name] = None - else: - ret[field.field_name] = field.to_representation(attribute) - - return ret - def get_effective_membership_level(self, project: Project) -> Optional[OrganizationMembership.Level]: team = project.teams.get(pk=project.pk) return self.user_permissions.team(team).effective_membership_level @@ -342,6 +268,9 @@ def update(self, instance: Project, validated_data: dict[str, Any]) -> Project: should_team_be_saved_too = True setattr(team, attr, value) else: + if attr == "name": # `name` should be updated on _both_ the Project and Team + should_team_be_saved_too = True + setattr(team, attr, value) setattr(instance, attr, value) instance.save() @@ -386,7 +315,7 @@ def safely_get_queryset(self, queryset): def get_serializer_class(self) -> type[serializers.BaseSerializer]: if self.action == "list": - return TeamBasicSerializer + return ProjectBasicSerializer return super().get_serializer_class() # NOTE: Team permissions are somewhat complex so we override the underlying viewset's get_permissions method diff --git a/posthog/api/shared.py b/posthog/api/shared.py index e37fe9de29297..49a390fad983f 100644 --- a/posthog/api/shared.py +++ b/posthog/api/shared.py @@ -2,12 +2,17 @@ This module contains serializers that are used across other serializers for nested representations. """ +import copy from typing import Optional from rest_framework import serializers from posthog.models import Organization, Team, User from posthog.models.organization import OrganizationMembership +from posthog.models.project import Project +from rest_framework.fields import SkipField +from rest_framework.relations import PKOnlyObject +from rest_framework.utils import model_meta class UserBasicSerializer(serializers.ModelSerializer): @@ -36,6 +41,110 @@ def get_hedgehog_config(self, user: User) -> Optional[dict]: return None +class ProjectBasicSerializer(serializers.ModelSerializer): + """ + Serializer for `Project` model with minimal attributes to speeed up loading and transfer times. + Also used for nested serializers. + """ + + class Meta: + model = Project + fields = ( + "id", + "uuid", # Compat with TeamSerializer + "organization", + "api_token", # Compat with TeamSerializer + "name", + "completed_snippet_onboarding", # Compat with TeamSerializer + "has_completed_onboarding_for", # Compat with TeamSerializer + "ingested_event", # Compat with TeamSerializer + "is_demo", # Compat with TeamSerializer + "timezone", # Compat with TeamSerializer + "access_control", # Compat with TeamSerializer + ) + read_only_fields = fields + team_passthrough_fields = { + "uuid", + "api_token", + "completed_snippet_onboarding", + "has_completed_onboarding_for", + "ingested_event", + "is_demo", + "timezone", + "access_control", + } + + def get_fields(self): + declared_fields = copy.deepcopy(self._declared_fields) + + info = model_meta.get_field_info(Project) + team_info = model_meta.get_field_info(Team) + for field_name, field in team_info.fields.items(): + if field_name in info.fields: + continue + info.fields[field_name] = field + info.fields_and_pk[field_name] = field + for field_name, relation in team_info.forward_relations.items(): + if field_name in info.forward_relations: + continue + info.forward_relations[field_name] = relation + info.relations[field_name] = relation + for accessor_name, relation in team_info.reverse_relations.items(): + if accessor_name in info.reverse_relations: + continue + info.reverse_relations[accessor_name] = relation + info.relations[accessor_name] = relation + + field_names = self.get_field_names(declared_fields, info) + + extra_kwargs = self.get_extra_kwargs() + extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs(field_names, declared_fields, extra_kwargs) + + fields = {} + for field_name in field_names: + if field_name in declared_fields: + fields[field_name] = declared_fields[field_name] + continue + extra_field_kwargs = extra_kwargs.get(field_name, {}) + source = extra_field_kwargs.get("source", "*") + if source == "*": + source = field_name + field_class, field_kwargs = self.build_field(source, info, model_class=Project, nested_depth=0) + field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs) + fields[field_name] = field_class(**field_kwargs) + fields.update(hidden_fields) + return fields + + def build_field(self, field_name, info, model_class, nested_depth): + if field_name in self.Meta.team_passthrough_fields: + model_class = Team + return super().build_field(field_name, info, model_class, nested_depth) + + def to_representation(self, instance): + """ + Object instance -> Dict of primitive datatypes. + """ + ret = {} + fields = self._readable_fields + + for field in fields: + try: + attribute_source = instance + if field.field_name in self.Meta.team_passthrough_fields: + attribute_source = instance.passthrough_team + attribute = field.get_attribute(attribute_source) + except SkipField: + continue + + check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute + if check_for_none is None: + ret[field.field_name] = None + else: + ret[field.field_name] = field.to_representation(attribute) + + return ret + + class TeamBasicSerializer(serializers.ModelSerializer): """ Serializer for `Team` model with minimal attributes to speeed up loading and transfer times. diff --git a/posthog/models/test/test_project.py b/posthog/models/test/test_project.py index d6bfe0ed3a36a..1e2e0cef2167a 100644 --- a/posthog/models/test/test_project.py +++ b/posthog/models/test/test_project.py @@ -17,7 +17,7 @@ def test_create_project_with_team_no_team_fields(self): self.assertEqual( team.name, - "Default project", # TODO: When Environments are rolled out, ensure this says "Default environment" + "Test project", # TODO: When Environments are rolled out, ensure this says "Default environment" ) self.assertEqual(team.organization, self.organization) self.assertEqual(team.project, project) diff --git a/posthog/test/test_middleware.py b/posthog/test/test_middleware.py index ce8bfeb71b7bb..f5b4190ef4293 100644 --- a/posthog/test/test_middleware.py +++ b/posthog/test/test_middleware.py @@ -124,7 +124,7 @@ class TestAutoProjectMiddleware(APIBaseTest): @classmethod def setUpTestData(cls): super().setUpTestData() - cls.base_app_num_queries = 40 + cls.base_app_num_queries = 45 # Create another team that the user does have access to cls.second_team = create_team(organization=cls.organization, name="Second Life")