diff --git a/posthog/warehouse/api/modeling.py b/posthog/warehouse/api/modeling.py index 309f773df645ba..b5a2d8953e80a2 100644 --- a/posthog/warehouse/api/modeling.py +++ b/posthog/warehouse/api/modeling.py @@ -1,4 +1,4 @@ -from rest_framework import serializers, viewsets +from rest_framework import request, response, serializers, viewsets from posthog.api.routing import TeamAndOrgViewSetMixin from posthog.api.shared import UserBasicSerializer @@ -12,8 +12,18 @@ class Meta: model = DataWarehouseModelPath -class DataWarehouseModelViewSet(TeamAndOrgViewSetMixin, viewsets.ReadOnlyModelViewSet): +class DataWarehouseModelPathViewSet(TeamAndOrgViewSetMixin, viewsets.ReadOnlyModelViewSet): scope_object = "INTERNAL" queryset = DataWarehouseModelPath.objects.all() serializer_class = DataWarehouseModelPathSerializer + + +class DataWarehouseModelDagViewSet(TeamAndOrgViewSetMixin, viewsets.ViewSet): + scope_object = "INTERNAL" + + def list(self, request: request.Request, *args, **kwargs) -> response.Response: + """Return this team's DAG as a set of edges and nodes""" + dag = DataWarehouseModelPath.objects.get_dag(self.team) + + return response.Response({"edges": dag.edges, "nodes": dag.nodes}) diff --git a/posthog/warehouse/api/test/test_modeling.py b/posthog/warehouse/api/test/test_modeling.py new file mode 100644 index 00000000000000..eb62b6a1d47fe0 --- /dev/null +++ b/posthog/warehouse/api/test/test_modeling.py @@ -0,0 +1,43 @@ +from posthog.test.base import APIBaseTest +from posthog.warehouse.models import DataWarehouseModelPath, DataWarehouseSavedQuery + + +class TestDag(APIBaseTest): + def test_get_dag(self): + parent_query = """\ + select + events.event, + persons.properties + from events + left join persons on events.person_id = persons.id + where events.event = 'login' and person.pdi != 'some_distinct_id' + """ + parent_saved_query = DataWarehouseSavedQuery.objects.create( + team=self.team, + name="my_model", + query={"query": parent_query}, + ) + child_saved_query = DataWarehouseSavedQuery.objects.create( + team=self.team, + name="my_model_child", + query={"query": "select * from my_model as my_other_model"}, + ) + DataWarehouseModelPath.objects.create_from_saved_query(parent_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(child_saved_query) + + response = self.client.get( + f"/api/projects/{self.team.id}/warehouse_dag", + ) + self.assertEqual(response.status_code, 200, response.content) + dag = response.json() + + self.assertIn([parent_saved_query.id.hex, child_saved_query.id.hex], dag["edges"]) + self.assertIn(["events", parent_saved_query.id.hex], dag["edges"]) + self.assertIn(["persons", parent_saved_query.id.hex], dag["edges"]) + self.assertEqual(len(dag["edges"]), 3) + + self.assertIn([child_saved_query.id.hex, "SavedQuery"], dag["nodes"]) + self.assertIn([parent_saved_query.id.hex, "SavedQuery"], dag["nodes"]) + self.assertIn(["events", "PostHog"], dag["nodes"]) + self.assertIn(["persons", "PostHog"], dag["nodes"]) + self.assertEqual(len(dag["nodes"]), 4)