From 6201ab4d8df05c39264a00f46cc2e4771d60e8f6 Mon Sep 17 00:00:00 2001 From: Juraj Majerik Date: Fri, 17 Nov 2023 11:07:57 +0100 Subject: [PATCH] revert cohort ids back to str --- posthog/api/feature_flag.py | 11 ++++---- posthog/api/organization_feature_flag.py | 10 ++++--- posthog/api/test/test_feature_flag_utils.py | 6 ++--- .../test/test_organization_feature_flag.py | 7 ++--- posthog/models/cohort/util.py | 27 +++++++++---------- posthog/models/feature_flag/feature_flag.py | 18 +++++++------ 6 files changed, 41 insertions(+), 38 deletions(-) diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py index 45bd8e077de1f..f513e9e74b6a4 100644 --- a/posthog/api/feature_flag.py +++ b/posthog/api/feature_flag.py @@ -503,11 +503,11 @@ def local_evaluation(self, request: request.Request, **kwargs): should_send_cohorts = "send_cohorts" in request.GET cohorts = {} - seen_cohorts_cache: Dict[int, Cohort] = {} + seen_cohorts_cache: Dict[str, Cohort] = {} if should_send_cohorts: seen_cohorts_cache = { - cohort.pk: cohort + str(cohort.pk): cohort for cohort in Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION).filter( team_id=self.team_id, deleted=False ) @@ -547,11 +547,12 @@ def local_evaluation(self, request: request.Request, **kwargs): ): # don't duplicate queries for already added cohorts if id not in cohorts: - if id in seen_cohorts_cache: - cohort = seen_cohorts_cache[id] + parsed_cohort_id = str(id) + if parsed_cohort_id in seen_cohorts_cache: + cohort = seen_cohorts_cache[parsed_cohort_id] else: cohort = Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION).get(id=id) - seen_cohorts_cache[id] = cohort + seen_cohorts_cache[parsed_cohort_id] = cohort if not cohort.is_static: cohorts[cohort.pk] = cohort.properties.to_dict() diff --git a/posthog/api/organization_feature_flag.py b/posthog/api/organization_feature_flag.py index aa2681a36a39a..391b036d3bc21 100644 --- a/posthog/api/organization_feature_flag.py +++ b/posthog/api/organization_feature_flag.py @@ -101,7 +101,7 @@ def copy_flags(self, request, *args, **kwargs): continue # get all linked cohorts, sorted by creation order - seen_cohorts_cache: Dict[int, Cohort] = {} + seen_cohorts_cache: Dict[str, Cohort] = {} sorted_cohort_ids = flag_to_copy.get_cohort_ids( seen_cohorts_cache=seen_cohorts_cache, sort_by_topological_order=True ) @@ -111,22 +111,24 @@ def copy_flags(self, request, *args, **kwargs): # create cohorts in the destination project if len(sorted_cohort_ids): for cohort_id in sorted_cohort_ids: - original_cohort = seen_cohorts_cache[cohort_id] + original_cohort = seen_cohorts_cache[str(cohort_id)] # search in destination project by name destination_cohort = Cohort.objects.filter( name=original_cohort.name, team_id=target_project_id, deleted=False ).first() + # create new cohort in the destination project if not destination_cohort: prop_group = Filter( data={"properties": original_cohort.filters["properties"], "is_simplified": True} ).property_groups + # we're going to reference the destination cohort - it must already exist! for prop in prop_group.flat: if prop.type == "cohort": original_child_cohort_id = prop.value - original_child_cohort = seen_cohorts_cache[original_child_cohort_id] + original_child_cohort = seen_cohorts_cache[str(original_child_cohort_id)] prop.value = name_to_dest_cohort_id[original_child_cohort.name] destination_cohort_serializer = CohortSerializer( @@ -155,7 +157,7 @@ def copy_flags(self, request, *args, **kwargs): for prop in props: if prop.get("type") == "cohort": original_cohort_id = prop["value"] - cohort_name = (seen_cohorts_cache[original_cohort_id]).name + cohort_name = (seen_cohorts_cache[str(original_cohort_id)]).name prop["value"] = name_to_dest_cohort_id[cohort_name] flag_data = { diff --git a/posthog/api/test/test_feature_flag_utils.py b/posthog/api/test/test_feature_flag_utils.py index b898bed86161c..690553b9ffe5b 100644 --- a/posthog/api/test/test_feature_flag_utils.py +++ b/posthog/api/test/test_feature_flag_utils.py @@ -50,9 +50,9 @@ def create_cohort(name): cohort_ids = {cohorts["a"].pk, cohorts["b"].pk, cohorts["c"].pk} seen_cohorts_cache = { - cohorts["a"].pk: cohorts["a"], - cohorts["b"].pk: cohorts["b"], - cohorts["c"].pk: cohorts["c"], + str(cohorts["a"].pk): cohorts["a"], + str(cohorts["b"].pk): cohorts["b"], + str(cohorts["c"].pk): cohorts["c"], } # (a)-->(c)-->(b) diff --git a/posthog/api/test/test_organization_feature_flag.py b/posthog/api/test/test_organization_feature_flag.py index 1c83af79e9c87..4913d0f86caf6 100644 --- a/posthog/api/test/test_organization_feature_flag.py +++ b/posthog/api/test/test_organization_feature_flag.py @@ -580,8 +580,8 @@ def connect(parent, child): # get topological order of the original cohorts original_cohorts_cache = {} for _, cohort in cohorts.items(): - original_cohorts_cache[cohort.id] = cohort - original_cohort_ids = set(original_cohorts_cache.keys()) + original_cohorts_cache[str(cohort.id)] = cohort + original_cohort_ids = {int(str_id) for str_id in original_cohorts_cache.keys()} topologically_sorted_original_cohort_ids = sort_cohorts_topologically( original_cohort_ids, original_cohorts_cache ) @@ -591,7 +591,8 @@ def connect(parent, child): topologically_sorted_original_cohort_ids_reversed = topologically_sorted_original_cohort_ids[::-1] def traverse(cohort, index): - expected_name = original_cohorts_cache[topologically_sorted_original_cohort_ids_reversed[index]].name + expected_cohort_id = topologically_sorted_original_cohort_ids_reversed[index] + expected_name = original_cohorts_cache[str(expected_cohort_id)].name self.assertEqual(expected_name, cohort.name) prop = cohort.filters["properties"]["values"][0] diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py index 3f87e11993180..c0a5dfb4a3eb0 100644 --- a/posthog/models/cohort/util.py +++ b/posthog/models/cohort/util.py @@ -440,7 +440,7 @@ def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]: def get_dependent_cohorts( cohort: Cohort, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, Cohort]] = None, + seen_cohorts_cache: Optional[Dict[str, Cohort]] = None, ) -> List[Cohort]: if seen_cohorts_cache is None: seen_cohorts_cache = {} @@ -449,32 +449,28 @@ def get_dependent_cohorts( seen_cohort_ids = set() seen_cohort_ids.add(cohort.id) - queue: List[int] = [] - for prop in cohort.properties.flat: - if prop.type == "cohort" and not isinstance(prop.value, list): - queue.append(int(prop.value)) + queue = [prop.value for prop in cohort.properties.flat if prop.type == "cohort"] while queue: cohort_id = queue.pop() try: - if cohort_id in seen_cohorts_cache: - cohort = seen_cohorts_cache[cohort_id] + parsed_cohort_id = str(cohort_id) + if parsed_cohort_id in seen_cohorts_cache: + cohort = seen_cohorts_cache[parsed_cohort_id] else: cohort = Cohort.objects.using(using_database).get(pk=cohort_id) - seen_cohorts_cache[cohort_id] = cohort + seen_cohorts_cache[parsed_cohort_id] = cohort if cohort.id not in seen_cohort_ids: cohorts.append(cohort) seen_cohort_ids.add(cohort.id) - for prop in cohort.properties.flat: - if prop.type == "cohort" and not isinstance(prop.value, list): - queue.append(int(prop.value)) + queue += [prop.value for prop in cohort.properties.flat if prop.type == "cohort"] except Cohort.DoesNotExist: continue return cohorts -def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[int, Cohort]) -> List[int]: +def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[str, Cohort]) -> List[int]: """ Sorts the given cohorts in an order where cohorts with no dependencies are placed first, followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list @@ -493,13 +489,13 @@ def traverse(cohort): # add child dependency_graph[cohort.id].append(int(prop.value)) - neighbor_cohort = seen_cohorts_cache[int(prop.value)] + neighbor_cohort = seen_cohorts_cache[str(prop.value)] if cohort.id not in seen: seen.add(cohort.id) traverse(neighbor_cohort) for cohort_id in cohort_ids: - cohort = seen_cohorts_cache[cohort_id] + cohort = seen_cohorts_cache[str(cohort_id)] traverse(cohort) # post-order DFS (children first, then the parent) @@ -508,7 +504,7 @@ def dfs(node, seen, sorted_arr): for neighbor in neighbors: if neighbor not in seen: dfs(neighbor, seen, sorted_arr) - sorted_arr.append(node) + sorted_arr.append(int(node)) seen.add(node) sorted_cohort_ids: List[int] = [] @@ -517,4 +513,5 @@ def dfs(node, seen, sorted_arr): if cohort_id not in seen: seen.add(cohort_id) dfs(cohort_id, seen, sorted_cohort_ids) + return sorted_cohort_ids diff --git a/posthog/models/feature_flag/feature_flag.py b/posthog/models/feature_flag/feature_flag.py index 418e01b5b15a6..36379563aa7f7 100644 --- a/posthog/models/feature_flag/feature_flag.py +++ b/posthog/models/feature_flag/feature_flag.py @@ -134,7 +134,7 @@ def get_filters(self): def transform_cohort_filters_for_easy_evaluation( self, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, Cohort]] = None, + seen_cohorts_cache: Optional[Dict[str, Cohort]] = None, ): """ Expands cohort filters into person property filters when possible. @@ -174,11 +174,12 @@ def transform_cohort_filters_for_easy_evaluation( # We cannot expand this cohort condition if it's not the only property in its group. return self.conditions try: - if cohort_id in seen_cohorts_cache: - cohort = seen_cohorts_cache[cohort_id] + parsed_cohort_id = str(cohort_id) + if parsed_cohort_id in seen_cohorts_cache: + cohort = seen_cohorts_cache[parsed_cohort_id] else: cohort = Cohort.objects.using(using_database).get(pk=cohort_id) - seen_cohorts_cache[cohort_id] = cohort + seen_cohorts_cache[parsed_cohort_id] = cohort except Cohort.DoesNotExist: return self.conditions if not cohort_condition: @@ -258,7 +259,7 @@ def transform_cohort_filters_for_easy_evaluation( def get_cohort_ids( self, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, Cohort]] = None, + seen_cohorts_cache: Optional[Dict[str, Cohort]] = None, sort_by_topological_order=False, ) -> List[int]: from posthog.models.cohort.util import get_dependent_cohorts, sort_cohorts_topologically @@ -273,11 +274,12 @@ def get_cohort_ids( if prop.get("type") == "cohort": cohort_id = prop.get("value") try: - if cohort_id in seen_cohorts_cache: - cohort: Cohort = seen_cohorts_cache[cohort_id] + parsed_cohort_id = str(cohort_id) + if parsed_cohort_id in seen_cohorts_cache: + cohort: Cohort = seen_cohorts_cache[parsed_cohort_id] else: cohort = Cohort.objects.using(using_database).get(pk=cohort_id) - seen_cohorts_cache[cohort_id] = cohort + seen_cohorts_cache[parsed_cohort_id] = cohort cohort_ids.add(cohort.pk) cohort_ids.update(