Skip to content

Commit

Permalink
revert cohort ids back to str
Browse files (browse the repository at this point in the history)
  • Loading branch information
jurajmajerik committed Nov 17, 2023
1 parent 19a0e17 commit 6201ab4
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 38 deletions.
11 changes: 6 additions & 5 deletions posthog/api/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,11 +503,11 @@ def local_evaluation(self, request: request.Request, **kwargs):
should_send_cohorts = "send_cohorts" in request.GET

cohorts = {}
seen_cohorts_cache: Dict[int, Cohort] = {}
seen_cohorts_cache: Dict[str, Cohort] = {}

if should_send_cohorts:
seen_cohorts_cache = {
cohort.pk: cohort
str(cohort.pk): cohort
for cohort in Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION).filter(
team_id=self.team_id, deleted=False
)
Expand Down Expand Up @@ -547,11 +547,12 @@ def local_evaluation(self, request: request.Request, **kwargs):
):
# don't duplicate queries for already added cohorts
if id not in cohorts:
if id in seen_cohorts_cache:
cohort = seen_cohorts_cache[id]
parsed_cohort_id = str(id)
if parsed_cohort_id in seen_cohorts_cache:
cohort = seen_cohorts_cache[parsed_cohort_id]
else:
cohort = Cohort.objects.using(DATABASE_FOR_LOCAL_EVALUATION).get(id=id)
seen_cohorts_cache[id] = cohort
seen_cohorts_cache[parsed_cohort_id] = cohort

if not cohort.is_static:
cohorts[cohort.pk] = cohort.properties.to_dict()
Expand Down
10 changes: 6 additions & 4 deletions posthog/api/organization_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def copy_flags(self, request, *args, **kwargs):
continue

# get all linked cohorts, sorted by creation order
seen_cohorts_cache: Dict[int, Cohort] = {}
seen_cohorts_cache: Dict[str, Cohort] = {}
sorted_cohort_ids = flag_to_copy.get_cohort_ids(
seen_cohorts_cache=seen_cohorts_cache, sort_by_topological_order=True
)
Expand All @@ -111,22 +111,24 @@ def copy_flags(self, request, *args, **kwargs):
# create cohorts in the destination project
if len(sorted_cohort_ids):
for cohort_id in sorted_cohort_ids:
original_cohort = seen_cohorts_cache[cohort_id]
original_cohort = seen_cohorts_cache[str(cohort_id)]

# search in destination project by name
destination_cohort = Cohort.objects.filter(
name=original_cohort.name, team_id=target_project_id, deleted=False
).first()

# create new cohort in the destination project
if not destination_cohort:
prop_group = Filter(
data={"properties": original_cohort.filters["properties"], "is_simplified": True}
).property_groups

# we're going to reference the destination cohort - it must already exist!
for prop in prop_group.flat:
if prop.type == "cohort":
original_child_cohort_id = prop.value
original_child_cohort = seen_cohorts_cache[original_child_cohort_id]
original_child_cohort = seen_cohorts_cache[str(original_child_cohort_id)]
prop.value = name_to_dest_cohort_id[original_child_cohort.name]

destination_cohort_serializer = CohortSerializer(
Expand Down Expand Up @@ -155,7 +157,7 @@ def copy_flags(self, request, *args, **kwargs):
for prop in props:
if prop.get("type") == "cohort":
original_cohort_id = prop["value"]
cohort_name = (seen_cohorts_cache[original_cohort_id]).name
cohort_name = (seen_cohorts_cache[str(original_cohort_id)]).name
prop["value"] = name_to_dest_cohort_id[cohort_name]

flag_data = {
Expand Down
6 changes: 3 additions & 3 deletions posthog/api/test/test_feature_flag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def create_cohort(name):

cohort_ids = {cohorts["a"].pk, cohorts["b"].pk, cohorts["c"].pk}
seen_cohorts_cache = {
cohorts["a"].pk: cohorts["a"],
cohorts["b"].pk: cohorts["b"],
cohorts["c"].pk: cohorts["c"],
str(cohorts["a"].pk): cohorts["a"],
str(cohorts["b"].pk): cohorts["b"],
str(cohorts["c"].pk): cohorts["c"],
}

# (a)-->(c)-->(b)
Expand Down
7 changes: 4 additions & 3 deletions posthog/api/test/test_organization_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ def connect(parent, child):
# get topological order of the original cohorts
original_cohorts_cache = {}
for _, cohort in cohorts.items():
original_cohorts_cache[cohort.id] = cohort
original_cohort_ids = set(original_cohorts_cache.keys())
original_cohorts_cache[str(cohort.id)] = cohort
original_cohort_ids = {int(str_id) for str_id in original_cohorts_cache.keys()}
topologically_sorted_original_cohort_ids = sort_cohorts_topologically(
original_cohort_ids, original_cohorts_cache
)
Expand All @@ -591,7 +591,8 @@ def connect(parent, child):
topologically_sorted_original_cohort_ids_reversed = topologically_sorted_original_cohort_ids[::-1]

def traverse(cohort, index):
expected_name = original_cohorts_cache[topologically_sorted_original_cohort_ids_reversed[index]].name
expected_cohort_id = topologically_sorted_original_cohort_ids_reversed[index]
expected_name = original_cohorts_cache[str(expected_cohort_id)].name
self.assertEqual(expected_name, cohort.name)

prop = cohort.filters["properties"]["values"][0]
Expand Down
27 changes: 12 additions & 15 deletions posthog/models/cohort/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]:
def get_dependent_cohorts(
cohort: Cohort,
using_database: str = "default",
seen_cohorts_cache: Optional[Dict[int, Cohort]] = None,
seen_cohorts_cache: Optional[Dict[str, Cohort]] = None,
) -> List[Cohort]:
if seen_cohorts_cache is None:
seen_cohorts_cache = {}
Expand All @@ -449,32 +449,28 @@ def get_dependent_cohorts(
seen_cohort_ids = set()
seen_cohort_ids.add(cohort.id)

queue: List[int] = []
for prop in cohort.properties.flat:
if prop.type == "cohort" and not isinstance(prop.value, list):
queue.append(int(prop.value))
queue = [prop.value for prop in cohort.properties.flat if prop.type == "cohort"]

while queue:
cohort_id = queue.pop()
try:
if cohort_id in seen_cohorts_cache:
cohort = seen_cohorts_cache[cohort_id]
parsed_cohort_id = str(cohort_id)
if parsed_cohort_id in seen_cohorts_cache:
cohort = seen_cohorts_cache[parsed_cohort_id]
else:
cohort = Cohort.objects.using(using_database).get(pk=cohort_id)
seen_cohorts_cache[cohort_id] = cohort
seen_cohorts_cache[parsed_cohort_id] = cohort
if cohort.id not in seen_cohort_ids:
cohorts.append(cohort)
seen_cohort_ids.add(cohort.id)
for prop in cohort.properties.flat:
if prop.type == "cohort" and not isinstance(prop.value, list):
queue.append(int(prop.value))
queue += [prop.value for prop in cohort.properties.flat if prop.type == "cohort"]
except Cohort.DoesNotExist:
continue

return cohorts


def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[int, Cohort]) -> List[int]:
def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[str, Cohort]) -> List[int]:
"""
Sorts the given cohorts in an order where cohorts with no dependencies are placed first,
followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list
Expand All @@ -493,13 +489,13 @@ def traverse(cohort):
# add child
dependency_graph[cohort.id].append(int(prop.value))

neighbor_cohort = seen_cohorts_cache[int(prop.value)]
neighbor_cohort = seen_cohorts_cache[str(prop.value)]
if cohort.id not in seen:
seen.add(cohort.id)
traverse(neighbor_cohort)

for cohort_id in cohort_ids:
cohort = seen_cohorts_cache[cohort_id]
cohort = seen_cohorts_cache[str(cohort_id)]
traverse(cohort)

# post-order DFS (children first, then the parent)
Expand All @@ -508,7 +504,7 @@ def dfs(node, seen, sorted_arr):
for neighbor in neighbors:
if neighbor not in seen:
dfs(neighbor, seen, sorted_arr)
sorted_arr.append(node)
sorted_arr.append(int(node))
seen.add(node)

sorted_cohort_ids: List[int] = []
Expand All @@ -517,4 +513,5 @@ def dfs(node, seen, sorted_arr):
if cohort_id not in seen:
seen.add(cohort_id)
dfs(cohort_id, seen, sorted_cohort_ids)

return sorted_cohort_ids
18 changes: 10 additions & 8 deletions posthog/models/feature_flag/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_filters(self):
def transform_cohort_filters_for_easy_evaluation(
self,
using_database: str = "default",
seen_cohorts_cache: Optional[Dict[int, Cohort]] = None,
seen_cohorts_cache: Optional[Dict[str, Cohort]] = None,
):
"""
Expands cohort filters into person property filters when possible.
Expand Down Expand Up @@ -174,11 +174,12 @@ def transform_cohort_filters_for_easy_evaluation(
# We cannot expand this cohort condition if it's not the only property in its group.
return self.conditions
try:
if cohort_id in seen_cohorts_cache:
cohort = seen_cohorts_cache[cohort_id]
parsed_cohort_id = str(cohort_id)
if parsed_cohort_id in seen_cohorts_cache:
cohort = seen_cohorts_cache[parsed_cohort_id]
else:
cohort = Cohort.objects.using(using_database).get(pk=cohort_id)
seen_cohorts_cache[cohort_id] = cohort
seen_cohorts_cache[parsed_cohort_id] = cohort
except Cohort.DoesNotExist:
return self.conditions
if not cohort_condition:
Expand Down Expand Up @@ -258,7 +259,7 @@ def transform_cohort_filters_for_easy_evaluation(
def get_cohort_ids(
self,
using_database: str = "default",
seen_cohorts_cache: Optional[Dict[int, Cohort]] = None,
seen_cohorts_cache: Optional[Dict[str, Cohort]] = None,
sort_by_topological_order=False,
) -> List[int]:
from posthog.models.cohort.util import get_dependent_cohorts, sort_cohorts_topologically
Expand All @@ -273,11 +274,12 @@ def get_cohort_ids(
if prop.get("type") == "cohort":
cohort_id = prop.get("value")
try:
if cohort_id in seen_cohorts_cache:
cohort: Cohort = seen_cohorts_cache[cohort_id]
parsed_cohort_id = str(cohort_id)
if parsed_cohort_id in seen_cohorts_cache:
cohort: Cohort = seen_cohorts_cache[parsed_cohort_id]
else:
cohort = Cohort.objects.using(using_database).get(pk=cohort_id)
seen_cohorts_cache[cohort_id] = cohort
seen_cohorts_cache[parsed_cohort_id] = cohort

cohort_ids.add(cohort.pk)
cohort_ids.update(
Expand Down

0 comments on commit 6201ab4

Please sign in to comment.