Skip to content

Commit

Permalink
feat(feature flags): copy cohorts linked to a flag (#18642)
Browse files Browse the repository at this point in the history
  • Loading branch information
jurajmajerik authored Nov 17, 2023
1 parent b8c188c commit 93e2011
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 13 deletions.
4 changes: 3 additions & 1 deletion frontend/src/scenes/feature-flags/featureFlagLogic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,9 @@ export const featureFlagLogic = kea<featureFlagLogicType>([
: 'copied'
lemonToast.success(`Feature flag ${operation} successfully!`)
} else {
lemonToast.error(`Error while saving feature flag: ${featureFlagCopy?.failed || featureFlagCopy}`)
lemonToast.error(
`Error while saving feature flag: ${JSON.stringify(featureFlagCopy?.failed) || featureFlagCopy}`
)
}

actions.loadProjectsWithCurrentFlag()
Expand Down
83 changes: 73 additions & 10 deletions posthog/api/organization_feature_flag.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from posthog.api.routing import StructuredViewSetMixin
from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.feature_flag import CanEditFeatureFlag
from posthog.models import FeatureFlag, Team
from posthog.permissions import OrganizationMemberPermissions
from typing import Dict
from django.core.exceptions import ObjectDoesNotExist
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated
Expand All @@ -12,6 +8,14 @@
viewsets,
status,
)
from posthog.api.cohort import CohortSerializer
from posthog.api.routing import StructuredViewSetMixin
from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.feature_flag import CanEditFeatureFlag
from posthog.models import FeatureFlag, Team
from posthog.models.cohort import Cohort
from posthog.models.filters.filter import Filter
from posthog.permissions import OrganizationMemberPermissions


class OrganizationFeatureFlagView(
Expand Down Expand Up @@ -86,7 +90,7 @@ def copy_flags(self, request, *args, **kwargs):
for target_project_id in target_project_ids:
# Target project does not exist
try:
Team.objects.get(id=target_project_id)
target_project = Team.objects.get(id=target_project_id)
except ObjectDoesNotExist:
failed_projects.append(
{
Expand All @@ -96,10 +100,65 @@ def copy_flags(self, request, *args, **kwargs):
)
continue

context = {
"request": request,
"team_id": target_project_id,
}
# get all linked cohorts, sorted by creation order
seen_cohorts_cache: Dict[str, Cohort] = {}
sorted_cohort_ids = flag_to_copy.get_cohort_ids(
seen_cohorts_cache=seen_cohorts_cache, sort_by_topological_order=True
)

# destination cohort id is different from original cohort id - create mapping
name_to_dest_cohort_id: Dict[str, int] = {}
# create cohorts in the destination project
if len(sorted_cohort_ids):
for cohort_id in sorted_cohort_ids:
original_cohort = seen_cohorts_cache[str(cohort_id)]

# search in destination project by name
destination_cohort = Cohort.objects.filter(
name=original_cohort.name, team_id=target_project_id, deleted=False
).first()

# create new cohort in the destination project
if not destination_cohort:
prop_group = Filter(
data={"properties": original_cohort.properties.to_dict(), "is_simplified": True}
).property_groups

for prop in prop_group.flat:
if prop.type == "cohort":
original_child_cohort_id = prop.value
original_child_cohort = seen_cohorts_cache[str(original_child_cohort_id)]
prop.value = name_to_dest_cohort_id[original_child_cohort.name]

destination_cohort_serializer = CohortSerializer(
data={
"team": target_project,
"name": original_cohort.name,
"groups": [],
"filters": {"properties": prop_group.to_dict()},
"description": original_cohort.description,
"is_static": original_cohort.is_static,
},
context={
"request": request,
"team_id": target_project.id,
},
)
destination_cohort_serializer.is_valid(raise_exception=True)
destination_cohort = destination_cohort_serializer.save()

if destination_cohort is not None:
name_to_dest_cohort_id[original_cohort.name] = destination_cohort.id

# reference correct destination cohort ids in the flag
for group in flag_to_copy.conditions:
props = group.get("properties", [])
for prop in props:
if isinstance(prop, dict) and prop.get("type") == "cohort":
original_cohort_id = prop["value"]
cohort_name = (seen_cohorts_cache[str(original_cohort_id)]).name
prop["value"] = name_to_dest_cohort_id[cohort_name]

flag_data = {
"key": flag_to_copy.key,
"name": flag_to_copy.name,
Expand All @@ -109,6 +168,10 @@ def copy_flags(self, request, *args, **kwargs):
"ensure_experience_continuity": flag_to_copy.ensure_experience_continuity,
"deleted": False,
}
context = {
"request": request,
"team_id": target_project_id,
}

existing_flag = FeatureFlag.objects.filter(
key=feature_flag_key, team_id=target_project_id, deleted=False
Expand Down
73 changes: 73 additions & 0 deletions posthog/api/test/test_feature_flag_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, Set
from posthog.test.base import (
APIBaseTest,
)
from posthog.models.cohort import Cohort
from posthog.models.cohort.util import sort_cohorts_topologically


class TestFeatureFlagUtils(APIBaseTest):
def setUp(self):
super().setUp()

def test_cohorts_sorted_topologically(self):
cohorts = {}

def create_cohort(name):
cohorts[name] = Cohort.objects.create(
team=self.team,
name=name,
filters={
"properties": {
"type": "AND",
"values": [
{"key": "name", "value": "test", "type": "person"},
],
}
},
)

create_cohort("a")
create_cohort("b")
create_cohort("c")

# (c)-->(b)
cohorts["c"].filters["properties"]["values"][0] = {
"key": "id",
"value": cohorts["b"].pk,
"type": "cohort",
"negation": True,
}
cohorts["c"].save()

# (a)-->(c)
cohorts["a"].filters["properties"]["values"][0] = {
"key": "id",
"value": cohorts["c"].pk,
"type": "cohort",
"negation": True,
}
cohorts["a"].save()

cohort_ids = {cohorts["a"].pk, cohorts["b"].pk, cohorts["c"].pk}
seen_cohorts_cache = {
str(cohorts["a"].pk): cohorts["a"],
str(cohorts["b"].pk): cohorts["b"],
str(cohorts["c"].pk): cohorts["c"],
}

# (a)-->(c)-->(b)
# create b first, since it doesn't depend on any other cohorts
# then c, because it depends on b
# then a, because it depends on c

# thus destination creation order: b, c, a
destination_creation_order = [cohorts["b"].pk, cohorts["c"].pk, cohorts["a"].pk]
topologically_sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, seen_cohorts_cache)
self.assertEqual(topologically_sorted_cohort_ids, destination_creation_order)

def test_empty_cohorts_set(self):
cohort_ids: Set[int] = set()
seen_cohorts_cache: Dict[str, Cohort] = {}
topologically_sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, seen_cohorts_cache)
self.assertEqual(topologically_sorted_cohort_ids, [])
Loading

0 comments on commit 93e2011

Please sign in to comment.