From d779154b5c197b72cde32f382600e85d6c7291ab Mon Sep 17 00:00:00 2001
From: Michael Matloka <dev@twixes.com>
Date: Mon, 9 Dec 2024 19:24:59 +0100
Subject: [PATCH] chore(environments): Update filtering of feature
 flags/surveys (#26677)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
---
 posthog/api/early_access_feature.py           |  6 +--
 posthog/api/feature_flag.py                   | 26 ++++++-----
 posthog/api/organization_feature_flag.py      | 14 +++---
 posthog/api/survey.py                         | 16 +++----
 .../test_early_access_feature.ambr            |  3 +-
 .../test_organization_feature_flag.ambr       | 14 +++---
 .../api/test/__snapshots__/test_survey.ambr   | 45 ++++++++++---------
 .../management/commands/sync_feature_flags.py | 26 ++++++-----
 posthog/models/feature_flag/flag_analytics.py |  3 +-
 posthog/models/feedback/survey.py             |  2 +-
 10 files changed, 86 insertions(+), 69 deletions(-)

diff --git a/posthog/api/early_access_feature.py b/posthog/api/early_access_feature.py
index 004725393b4db..a7cc6cd42558e 100644
--- a/posthog/api/early_access_feature.py
+++ b/posthog/api/early_access_feature.py
@@ -270,9 +270,9 @@ def early_access_features(request: Request):
         )
 
     early_access_features = MinimalEarlyAccessFeatureSerializer(
-        EarlyAccessFeature.objects.filter(team_id=team.id, stage=EarlyAccessFeature.Stage.BETA).select_related(
-            "feature_flag"
-        ),
+        EarlyAccessFeature.objects.filter(
+            team__project_id=team.project_id, stage=EarlyAccessFeature.Stage.BETA
+        ).select_related("feature_flag"),
         many=True,
     ).data
 
diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py
index 11711875b0d96..f20c1a4a6105a 100644
--- a/posthog/api/feature_flag.py
+++ b/posthog/api/feature_flag.py
@@ -355,7 +355,7 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag
 
         try:
             FeatureFlag.objects.filter(
-                key=validated_data["key"], team_id=self.context["team_id"], deleted=True
+                key=validated_data["key"], team__project_id=self.context["project_id"], deleted=True
             ).delete()
         except deletion.RestrictedError:
             raise exceptions.ValidationError(
@@ -385,7 +385,9 @@ def update(self, instance: FeatureFlag, validated_data: dict, *args: Any, **kwar
         request = self.context["request"]
         validated_key = validated_data.get("key", None)
         if validated_key:
-            FeatureFlag.objects.filter(key=validated_key, team=instance.team, deleted=True).delete()
+            FeatureFlag.objects.filter(
+                key=validated_key, team__project_id=instance.team.project_id, deleted=True
+            ).delete()
         self._update_filters(validated_data)
 
         analytics_dashboards = validated_data.pop("analytics_dashboards", None)
@@ -519,11 +521,11 @@ def safely_get_queryset(self, queryset) -> QuerySet:
                 .prefetch_related("surveys_linked_flag")
             )
 
-            survey_targeting_flags = Survey.objects.filter(team=self.team, targeting_flag__isnull=False).values_list(
-                "targeting_flag_id", flat=True
-            )
+            survey_targeting_flags = Survey.objects.filter(
+                team__project_id=self.team.project_id, targeting_flag__isnull=False
+            ).values_list("targeting_flag_id", flat=True)
             survey_internal_targeting_flags = Survey.objects.filter(
-                team=self.team, internal_targeting_flag__isnull=False
+                team__project_id=self.team.project_id, internal_targeting_flag__isnull=False
             ).values_list("internal_targeting_flag_id", flat=True)
             queryset = queryset.exclude(Q(id__in=survey_targeting_flags)).exclude(
                 Q(id__in=survey_internal_targeting_flags)
@@ -650,7 +652,9 @@ def my_flags(self, request: request.Request, **kwargs):
         if not request.user.is_authenticated:  # for mypy
             raise exceptions.NotAuthenticated()
 
-        feature_flags = list(FeatureFlag.objects.filter(team=self.team, deleted=False).order_by("-created_at"))
+        feature_flags = list(
+            FeatureFlag.objects.filter(team__project_id=self.team.project_id, deleted=False).order_by("-created_at")
+        )
 
         if not feature_flags:
             return Response([])
@@ -774,9 +778,9 @@ def evaluation_reasons(self, request: request.Request, **kwargs):
                 "evaluation": reasons[flag_key],
             }
 
-        disabled_flags = FeatureFlag.objects.filter(team_id=self.team_id, active=False, deleted=False).values_list(
-            "key", flat=True
-        )
+        disabled_flags = FeatureFlag.objects.filter(
+            team__project_id=self.project_id, active=False, deleted=False
+        ).values_list("key", flat=True)
 
         for flag_key in disabled_flags:
             flags_with_evaluation_reasons[flag_key] = {
@@ -859,7 +863,7 @@ def activity(self, request: request.Request, **kwargs):
         page = int(request.query_params.get("page", "1"))
 
         item_id = kwargs["pk"]
-        if not FeatureFlag.objects.filter(id=item_id, team_id=self.team_id).exists():
+        if not FeatureFlag.objects.filter(id=item_id, team__project_id=self.project_id).exists():
             return Response("", status=status.HTTP_404_NOT_FOUND)
 
         activity_page = load_activity(
diff --git a/posthog/api/organization_feature_flag.py b/posthog/api/organization_feature_flag.py
index 46b1652d7f9b5..c86781fc87876 100644
--- a/posthog/api/organization_feature_flag.py
+++ b/posthog/api/organization_feature_flag.py
@@ -1,4 +1,3 @@
-from django.core.exceptions import ObjectDoesNotExist
 from rest_framework.response import Response
 from posthog.api.utils import action
 from rest_framework import (
@@ -66,7 +65,7 @@ def copy_flags(self, request, *args, **kwargs):
 
         # Fetch the flag to copy
         try:
-            flag_to_copy = FeatureFlag.objects.get(key=feature_flag_key, team_id=from_project)
+            flag_to_copy = FeatureFlag.objects.get(key=feature_flag_key, team__project_id=from_project)
         except FeatureFlag.DoesNotExist:
             return Response({"error": "Feature flag to copy does not exist."}, status=status.HTTP_400_BAD_REQUEST)
 
@@ -82,9 +81,8 @@ def copy_flags(self, request, *args, **kwargs):
 
         for target_project_id in target_project_ids:
             # Target project does not exist
-            try:
-                target_project = Team.objects.get(id=target_project_id)
-            except ObjectDoesNotExist:
+            target_team = Team.objects.filter(project_id=target_project_id).first()
+            if target_team is None:
                 failed_projects.append(
                     {
                         "project_id": target_project_id,
@@ -134,7 +132,7 @@ def copy_flags(self, request, *args, **kwargs):
 
                         destination_cohort_serializer = CohortSerializer(
                             data={
-                                "team": target_project,
+                                "team": target_team,
                                 "name": original_cohort.name,
                                 "groups": [],
                                 "filters": {"properties": prop_group.to_dict()},
@@ -143,7 +141,7 @@ def copy_flags(self, request, *args, **kwargs):
                             },
                             context={
                                 "request": request,
-                                "team_id": target_project.id,
+                                "team_id": target_team.id,
                             },
                         )
                         destination_cohort_serializer.is_valid(raise_exception=True)
@@ -183,7 +181,7 @@ def copy_flags(self, request, *args, **kwargs):
             }
 
             existing_flag = FeatureFlag.objects.filter(
-                key=feature_flag_key, team_id=target_project_id, deleted=False
+                key=feature_flag_key, team__project_id=target_project_id, deleted=False
             ).first()
             # Update existing flag
             if existing_flag:
diff --git a/posthog/api/survey.py b/posthog/api/survey.py
index 79af99fe9d7ac..125f258f047f5 100644
--- a/posthog/api/survey.py
+++ b/posthog/api/survey.py
@@ -291,7 +291,7 @@ def validate(self, data):
 
         if (
             self.context["request"].method == "POST"
-            and Survey.objects.filter(name=data.get("name"), team_id=self.context["team_id"]).exists()
+            and Survey.objects.filter(name=data.get("name"), team__project_id=self.context["project_id"]).exists()
         ):
             raise serializers.ValidationError("There is already a survey with this name.", code="unique")
 
@@ -300,7 +300,7 @@ def validate(self, data):
         if (
             existing_survey
             and existing_survey.name != data.get("name")
-            and Survey.objects.filter(name=data.get("name"), team_id=self.context["team_id"])
+            and Survey.objects.filter(name=data.get("name"), team__project_id=self.context["project_id"])
             .exclude(id=existing_survey.id)
             .exists()
         ):
@@ -686,9 +686,9 @@ def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
 
     @action(methods=["GET"], detail=False, required_scopes=["survey:read"])
     def responses_count(self, request: request.Request, **kwargs):
-        earliest_survey_start_date = Survey.objects.filter(team_id=self.team_id).aggregate(Min("start_date"))[
-            "start_date__min"
-        ]
+        earliest_survey_start_date = Survey.objects.filter(team__project_id=self.project_id).aggregate(
+            Min("start_date")
+        )["start_date__min"]
         data = sync_execute(
             f"""
             SELECT JSONExtractString(properties, '$survey_id') as survey_id, count()
@@ -721,7 +721,7 @@ def activity(self, request: request.Request, **kwargs):
 
         item_id = kwargs["pk"]
 
-        if not Survey.objects.filter(id=item_id, team_id=self.team_id).exists():
+        if not Survey.objects.filter(id=item_id, team__project_id=self.project_id).exists():
             return Response(status=status.HTTP_404_NOT_FOUND)
 
         activity_page = load_activity(
@@ -742,7 +742,7 @@ def summarize_responses(self, request: request.Request, **kwargs):
 
         survey_id = kwargs["pk"]
 
-        if not Survey.objects.filter(id=survey_id, team_id=self.team_id).exists():
+        if not Survey.objects.filter(id=survey_id, team__project_id=self.project_id).exists():
             return Response(status=status.HTTP_404_NOT_FOUND)
 
         survey = self.get_object()
@@ -893,7 +893,7 @@ def surveys(request: Request):
         )
 
     surveys = SurveyAPISerializer(
-        Survey.objects.filter(team_id=team.id)
+        Survey.objects.filter(team__project_id=team.project_id)
         .exclude(archived=True)
         .select_related("linked_flag", "targeting_flag", "internal_targeting_flag")
         .prefetch_related("actions"),
diff --git a/posthog/api/test/__snapshots__/test_early_access_feature.ambr b/posthog/api/test/__snapshots__/test_early_access_feature.ambr
index 401bb26d24a9a..9e528aa3a5f0a 100644
--- a/posthog/api/test/__snapshots__/test_early_access_feature.ambr
+++ b/posthog/api/test/__snapshots__/test_early_access_feature.ambr
@@ -349,9 +349,10 @@
          "posthog_featureflag"."usage_dashboard_id",
          "posthog_featureflag"."has_enriched_analytics"
   FROM "posthog_earlyaccessfeature"
+  INNER JOIN "posthog_team" ON ("posthog_earlyaccessfeature"."team_id" = "posthog_team"."id")
   LEFT OUTER JOIN "posthog_featureflag" ON ("posthog_earlyaccessfeature"."feature_flag_id" = "posthog_featureflag"."id")
   WHERE ("posthog_earlyaccessfeature"."stage" = 'beta'
-         AND "posthog_earlyaccessfeature"."team_id" = 99999)
+         AND "posthog_team"."project_id" = 99999)
   '''
 # ---
 # name: TestPreviewList.test_early_access_features.2
diff --git a/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr b/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr
index 1fb0fcefbf20e..e20c804218649 100644
--- a/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr
+++ b/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr
@@ -865,8 +865,9 @@
          "posthog_featureflag"."usage_dashboard_id",
          "posthog_featureflag"."has_enriched_analytics"
   FROM "posthog_featureflag"
+  INNER JOIN "posthog_team" ON ("posthog_featureflag"."team_id" = "posthog_team"."id")
   WHERE ("posthog_featureflag"."key" = 'copied-flag-key'
-         AND "posthog_featureflag"."team_id" = 99999)
+         AND "posthog_team"."project_id" = 99999)
   LIMIT 21
   '''
 # ---
@@ -1696,8 +1697,9 @@
          "posthog_team"."external_data_workspace_id",
          "posthog_team"."external_data_workspace_last_synced_at"
   FROM "posthog_team"
-  WHERE "posthog_team"."id" = 99999
-  LIMIT 21
+  WHERE "posthog_team"."project_id" = 99999
+  ORDER BY "posthog_team"."id" ASC
+  LIMIT 1
   '''
 # ---
 # name: TestOrganizationFeatureFlagCopy.test_copy_feature_flag_create_new.50
@@ -1926,9 +1928,10 @@
          "posthog_featureflag"."usage_dashboard_id",
          "posthog_featureflag"."has_enriched_analytics"
   FROM "posthog_featureflag"
+  INNER JOIN "posthog_team" ON ("posthog_featureflag"."team_id" = "posthog_team"."id")
   WHERE (NOT "posthog_featureflag"."deleted"
          AND "posthog_featureflag"."key" = 'copied-flag-key'
-         AND "posthog_featureflag"."team_id" = 99999)
+         AND "posthog_team"."project_id" = 99999)
   ORDER BY "posthog_featureflag"."id" ASC
   LIMIT 1
   '''
@@ -2280,9 +2283,10 @@
          "posthog_featureflag"."usage_dashboard_id",
          "posthog_featureflag"."has_enriched_analytics"
   FROM "posthog_featureflag"
+  INNER JOIN "posthog_team" ON ("posthog_featureflag"."team_id" = "posthog_team"."id")
   WHERE ("posthog_featureflag"."deleted"
          AND "posthog_featureflag"."key" = 'copied-flag-key'
-         AND "posthog_featureflag"."team_id" = 99999)
+         AND "posthog_team"."project_id" = 99999)
   '''
 # ---
 # name: TestOrganizationFeatureFlagCopy.test_copy_feature_flag_create_new.9
diff --git a/posthog/api/test/__snapshots__/test_survey.ambr b/posthog/api/test/__snapshots__/test_survey.ambr
index 71dd2da339498..b29c45cab0c1b 100644
--- a/posthog/api/test/__snapshots__/test_survey.ambr
+++ b/posthog/api/test/__snapshots__/test_survey.ambr
@@ -27,10 +27,11 @@
   '''
   SELECT COUNT(*) AS "__count"
   FROM "posthog_survey"
+  INNER JOIN "posthog_team" ON ("posthog_survey"."team_id" = "posthog_team"."id")
   WHERE (NOT "posthog_survey"."archived"
          AND "posthog_survey"."end_date" IS NULL
          AND "posthog_survey"."start_date" IS NOT NULL
-         AND "posthog_survey"."team_id" = 99999
+         AND "posthog_team"."project_id" = 99999
          AND NOT ("posthog_survey"."type" = 'api'))
   '''
 # ---
@@ -482,10 +483,11 @@
   '''
   SELECT COUNT(*) AS "__count"
   FROM "posthog_survey"
+  INNER JOIN "posthog_team" ON ("posthog_survey"."team_id" = "posthog_team"."id")
   WHERE (NOT "posthog_survey"."archived"
          AND "posthog_survey"."end_date" IS NULL
          AND "posthog_survey"."start_date" IS NOT NULL
-         AND "posthog_survey"."team_id" = 99999
+         AND "posthog_team"."project_id" = 99999
          AND NOT ("posthog_survey"."type" = 'api'))
   '''
 # ---
@@ -630,21 +632,6 @@
          "posthog_featureflag"."ensure_experience_continuity",
          "posthog_featureflag"."usage_dashboard_id",
          "posthog_featureflag"."has_enriched_analytics",
-         T4."id",
-         T4."key",
-         T4."name",
-         T4."filters",
-         T4."rollout_percentage",
-         T4."team_id",
-         T4."created_by_id",
-         T4."created_at",
-         T4."deleted",
-         T4."active",
-         T4."rollback_conditions",
-         T4."performed_rollback",
-         T4."ensure_experience_continuity",
-         T4."usage_dashboard_id",
-         T4."has_enriched_analytics",
          T5."id",
          T5."key",
          T5."name",
@@ -659,12 +646,28 @@
          T5."performed_rollback",
          T5."ensure_experience_continuity",
          T5."usage_dashboard_id",
-         T5."has_enriched_analytics"
+         T5."has_enriched_analytics",
+         T6."id",
+         T6."key",
+         T6."name",
+         T6."filters",
+         T6."rollout_percentage",
+         T6."team_id",
+         T6."created_by_id",
+         T6."created_at",
+         T6."deleted",
+         T6."active",
+         T6."rollback_conditions",
+         T6."performed_rollback",
+         T6."ensure_experience_continuity",
+         T6."usage_dashboard_id",
+         T6."has_enriched_analytics"
   FROM "posthog_survey"
+  INNER JOIN "posthog_team" ON ("posthog_survey"."team_id" = "posthog_team"."id")
   LEFT OUTER JOIN "posthog_featureflag" ON ("posthog_survey"."linked_flag_id" = "posthog_featureflag"."id")
-  LEFT OUTER JOIN "posthog_featureflag" T4 ON ("posthog_survey"."targeting_flag_id" = T4."id")
-  LEFT OUTER JOIN "posthog_featureflag" T5 ON ("posthog_survey"."internal_targeting_flag_id" = T5."id")
-  WHERE ("posthog_survey"."team_id" = 99999
+  LEFT OUTER JOIN "posthog_featureflag" T5 ON ("posthog_survey"."targeting_flag_id" = T5."id")
+  LEFT OUTER JOIN "posthog_featureflag" T6 ON ("posthog_survey"."internal_targeting_flag_id" = T6."id")
+  WHERE ("posthog_team"."project_id" = 99999
          AND NOT ("posthog_survey"."archived"))
   '''
 # ---
diff --git a/posthog/management/commands/sync_feature_flags.py b/posthog/management/commands/sync_feature_flags.py
index 4e26061603691..22390a90f04e8 100644
--- a/posthog/management/commands/sync_feature_flags.py
+++ b/posthog/management/commands/sync_feature_flags.py
@@ -2,7 +2,7 @@
 
 from django.core.management.base import BaseCommand
 
-from posthog.models import FeatureFlag, Team, User
+from posthog.models import FeatureFlag, Project, User
 
 INACTIVE_FLAGS = [
     "cloud-announcement",
@@ -12,7 +12,7 @@
 
 
 class Command(BaseCommand):
-    help = "Add and enable all feature flags in frontend/src/lib/constants.tsx for all teams"
+    help = "Add and enable all feature flags in frontend/src/lib/constants.tsx for all projects"
 
     def handle(self, *args, **options):
         flags: dict[str, str] = {}
@@ -37,23 +37,27 @@ def handle(self, *args, **options):
                     parsing_flags = True
 
         first_user = cast(User, User.objects.first())
-        for team in Team.objects.all():
-            existing_flags = FeatureFlag.objects.filter(team=team).values_list("key", flat=True)
-            deleted_flags = FeatureFlag.objects.filter(team=team, deleted=True).values_list("key", flat=True)
+        for project in Project.objects.all():
+            existing_flags = FeatureFlag.objects.filter(team__project_id=project.id).values_list("key", flat=True)
+            deleted_flags = FeatureFlag.objects.filter(team__project_id=project.id, deleted=True).values_list(
+                "key", flat=True
+            )
             for flag in flags.keys():
                 flag_type = flags[flag]
                 is_enabled = flag not in INACTIVE_FLAGS
 
                 if flag in deleted_flags:
-                    ff = FeatureFlag.objects.filter(team=team, key=flag)[0]
+                    ff = FeatureFlag.objects.filter(team__project_id=project.id, key=flag)[0]
                     ff.deleted = False
                     ff.active = is_enabled
                     ff.save()
-                    print(f"Undeleted feature flag '{flag} for team {team.id} {' - ' + team.name if team.name else ''}")
+                    print(
+                        f"Undeleted feature flag '{flag} for project {project.id} {' - ' + project.name if project.name else ''}"
+                    )
                 elif flag not in existing_flags:
                     if flag_type == "multivariate":
                         FeatureFlag.objects.create(
-                            team=team,
+                            team=project.teams.first(),
                             rollout_percentage=100,
                             name=flag,
                             key=flag,
@@ -79,11 +83,13 @@ def handle(self, *args, **options):
                         )
                     else:
                         FeatureFlag.objects.create(
-                            team=team,
+                            team=project.teams.first(),
                             rollout_percentage=100,
                             name=flag,
                             key=flag,
                             created_by=first_user,
                             active=is_enabled,
                         )
-                    print(f"Created feature flag '{flag} for team {team.id} {' - ' + team.name if team.name else ''}")
+                    print(
+                        f"Created feature flag '{flag} for project {project.id} {' - ' + project.name if project.name else ''}"
+                    )
diff --git a/posthog/models/feature_flag/flag_analytics.py b/posthog/models/feature_flag/flag_analytics.py
index f62ed1934eca8..6579aa92d1041 100644
--- a/posthog/models/feature_flag/flag_analytics.py
+++ b/posthog/models/feature_flag/flag_analytics.py
@@ -139,9 +139,10 @@ def find_flags_with_enriched_analytics(begin: datetime, end: datetime):
     for row in result:
         team_id = row[0]
         flag_key = row[1]
+        team = Team.objects.only("project_id").get(id=team_id)
 
         try:
-            flag = FeatureFlag.objects.get(team_id=team_id, key=flag_key)
+            flag = FeatureFlag.objects.get(team__project_id=team.project_id, key=flag_key)
             if not flag.has_enriched_analytics:
                 flag.has_enriched_analytics = True
                 flag.save()
diff --git a/posthog/models/feedback/survey.py b/posthog/models/feedback/survey.py
index 2fc8088d6902b..b178fea2418e9 100644
--- a/posthog/models/feedback/survey.py
+++ b/posthog/models/feedback/survey.py
@@ -292,7 +292,7 @@ def update_survey_iterations(sender, instance, *args, **kwargs):
 def update_surveys_opt_in(sender, instance, **kwargs):
     active_surveys_count = (
         Survey.objects.filter(
-            team_id=instance.team_id,
+            team__project_id=instance.team.project_id,
             start_date__isnull=False,
             end_date__isnull=True,
             archived=False,