diff --git a/posthog/warehouse/api/saved_query.py b/posthog/warehouse/api/saved_query.py index db437861f967a..0c233001af365 100644 --- a/posthog/warehouse/api/saved_query.py +++ b/posthog/warehouse/api/saved_query.py @@ -160,25 +160,25 @@ def ancestors(self, request: request.Request, *args, **kwargs) -> response.Respo look further back into the ancestor tree. If `level` overshoots (i.e. points to only ancestors beyond the root), we return an empty list. """ - level = request.data.get("level", 1) + up_to_level = request.data.get("level", None) saved_query = self.get_object() saved_query_id = saved_query.id.hex - lquery = f"*{{{level},}}.{saved_query_id}" + lquery = f"*{{1,}}.{saved_query_id}" paths = DataWarehouseModelPath.objects.filter(team=saved_query.team, path__lquery=lquery) if not paths: return response.Response({"ancestors": []}) - ancestors = set() + ancestors: set[str] = set() for model_path in paths: - offset = len(model_path.path) - level - 1 # -1 corrects for level being 1-indexed + if up_to_level is None: + start = 0 + else: + start = (int(up_to_level) * -1) - 1 - if offset < 0: - continue - - ancestors.add(model_path.path[offset]) + ancestors = ancestors.union(model_path.path[start:-1]) return response.Response({"ancestors": ancestors}) @@ -190,25 +190,25 @@ def descendants(self, request: request.Request, *args, **kwargs) -> response.Res look further ahead into the descendants tree. If `level` overshoots (i.e. points to only descendants further than a leaf), we return an empty list. """ - level = request.data.get("level", 1) + up_to_level = request.data.get("level", None) saved_query = self.get_object() saved_query_id = saved_query.id.hex - lquery = f"*.{saved_query_id}.*{{{level},}}" + lquery = f"*.{saved_query_id}.*{{1,}}" paths = DataWarehouseModelPath.objects.filter(team=saved_query.team, path__lquery=lquery) if not paths: return response.Response({"descendants": []}) - descendants = set() - + descendants: set[str] = set() for model_path in paths: - offset = model_path.path.index(saved_query_id) + level - - if offset > len(model_path.path): - continue + start = model_path.path.index(saved_query_id) + 1 + if up_to_level is None: + end = len(model_path.path) + else: + end = start + up_to_level - descendants.add(model_path.path[offset]) + descendants = descendants.union(model_path.path[start:end]) return response.Response({"descendants": descendants}) diff --git a/posthog/warehouse/api/test/test_saved_query.py b/posthog/warehouse/api/test/test_saved_query.py index 80deaad72ca3a..a0abdf02c5e98 100644 --- a/posthog/warehouse/api/test/test_saved_query.py +++ b/posthog/warehouse/api/test/test_saved_query.py @@ -230,24 +230,33 @@ def test_ancestors(self): self.assertEqual(response.status_code, 200, response.content) child_ancestors = response.json()["ancestors"] - self.assertEqual(child_ancestors, [uuid.UUID(saved_query_parent_id).hex]) + child_ancestors.sort() + self.assertEqual(child_ancestors, sorted([uuid.UUID(saved_query_parent_id).hex, "events", "persons"])) response = self.client.post( - f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_child_id}/ancestors", {"level": 2} + f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_child_id}/ancestors", {"level": 1} ) + self.assertEqual(response.status_code, 200, response.content) + child_ancestors_level_1 = response.json()["ancestors"] + child_ancestors_level_1.sort() + self.assertEqual(child_ancestors_level_1, [uuid.UUID(saved_query_parent_id).hex]) + + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_child_id}/ancestors", {"level": 2} + ) self.assertEqual(response.status_code, 200, response.content) child_ancestors_level_2 = response.json()["ancestors"] child_ancestors_level_2.sort() - self.assertEqual(child_ancestors_level_2, ["events", "persons"]) + self.assertEqual(child_ancestors_level_2, sorted([uuid.UUID(saved_query_parent_id).hex, "events", "persons"])) response = self.client.post( f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_child_id}/ancestors", {"level": 10} ) - self.assertEqual(response.status_code, 200, response.content) child_ancestors_level_10 = response.json()["ancestors"] - self.assertEqual(child_ancestors_level_10, []) + child_ancestors_level_10.sort() + self.assertEqual(child_ancestors_level_2, sorted([uuid.UUID(saved_query_parent_id).hex, "events", "persons"])) def test_descendants(self): query = """\ @@ -281,23 +290,69 @@ def test_descendants(self): }, ) + response_grand_child = self.client.post( + f"/api/projects/{self.team.id}/warehouse_saved_queries/", + { + "name": "event_view_3", + "query": { + "kind": "HogQLQuery", + "query": "select event as event from event_view_2", + }, + }, + ) + self.assertEqual(response_parent.status_code, 201, response_parent.content) self.assertEqual(response_child.status_code, 201, response_child.content) + self.assertEqual(response_grand_child.status_code, 201, response_grand_child.content) saved_query_parent_id = response_parent.json()["id"] saved_query_child_id = response_child.json()["id"] + saved_query_grand_child_id = response_grand_child.json()["id"] response = self.client.post( f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_parent_id}/descendants", ) self.assertEqual(response.status_code, 200, response.content) parent_descendants = response.json()["descendants"] - self.assertEqual(parent_descendants, [uuid.UUID(saved_query_child_id).hex]) + self.assertEqual( + sorted(parent_descendants), + sorted([uuid.UUID(saved_query_child_id).hex, uuid.UUID(saved_query_grand_child_id).hex]), + ) + + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_parent_id}/descendants", {"level": 1} + ) + + self.assertEqual(response.status_code, 200, response.content) + parent_descendants_level_1 = response.json()["descendants"] + self.assertEqual( + parent_descendants_level_1, + [uuid.UUID(saved_query_child_id).hex], + ) + + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_parent_id}/descendants", {"level": 2} + ) + + self.assertEqual(response.status_code, 200, response.content) + parent_descendants_level_2 = response.json()["descendants"] + self.assertEqual( + sorted(parent_descendants_level_2), + sorted([uuid.UUID(saved_query_child_id).hex, uuid.UUID(saved_query_grand_child_id).hex]), + ) response = self.client.post( f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_child_id}/descendants", ) + self.assertEqual(response.status_code, 200, response.content) + child_ancestors = response.json()["descendants"] + self.assertEqual(child_ancestors, [uuid.UUID(saved_query_grand_child_id).hex]) + + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_saved_queries/{saved_query_grand_child_id}/descendants", + ) + self.assertEqual(response.status_code, 200, response.content) child_ancestors = response.json()["descendants"] self.assertEqual(child_ancestors, [])