diff --git a/posthog/migrations/0453_datawarehousemodelpath_and_more.py b/posthog/migrations/0453_datawarehousemodelpath_and_more.py index 63e3b487ce5b13..3a7bc811fb6db0 100644 --- a/posthog/migrations/0453_datawarehousemodelpath_and_more.py +++ b/posthog/migrations/0453_datawarehousemodelpath_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.14 on 2024-08-07 13:24 +# Generated by Django 4.2.14 on 2024-08-09 17:28 import django.contrib.postgres.indexes import django.db.models.constraints @@ -59,7 +59,6 @@ class Migration(migrations.Migration): options={ "indexes": [ models.Index(fields=["team_id", "path"], name="team_id_path"), - models.Index(fields=["team_id", "table_id"], name="team_id_table_id"), models.Index(fields=["team_id", "saved_query_id"], name="team_id_saved_query_id"), django.contrib.postgres.indexes.GistIndex(models.F("path"), name="model_path_path"), ], diff --git a/posthog/warehouse/models/modeling.py b/posthog/warehouse/models/modeling.py index 3e948c5b08a98f..3acbe176c44fd1 100644 --- a/posthog/warehouse/models/modeling.py +++ b/posthog/warehouse/models/modeling.py @@ -1,6 +1,6 @@ import collections.abc import dataclasses -import typing +import enum import uuid from django.contrib.postgres import indexes as pg_indexes @@ -8,6 +8,8 @@ from django.db import connection, models, transaction from posthog.hogql import ast +from posthog.hogql.database.database import Database, create_hogql_database +from posthog.hogql.parser import parse_select from posthog.models.team import Team from posthog.models.user import User from posthog.models.utils import ( @@ -20,23 +22,16 @@ from posthog.warehouse.models.datawarehouse_saved_query import DataWarehouseSavedQuery from posthog.warehouse.models.table import DataWarehouseTable -POSTHOG_ROOT_SOURCES = { - "events", - "groups", - "persons", - "person_distinct_ids", - "session_replay_events", - "cohort_people", - "static_cohort_people", - "log_entries", - "sessions", - "heatmaps", -} - LabelPath = list[str] class LabelTreeField(models.Field): + """A Django model field for a PostgreSQL label tree. + + We represent label trees in Python as a list of strings, each item + in the list being one of the labels of the underlying ltree. + """ + description = "A PostgreSQL label tree field provided by the ltree extension" def db_type(self, connection): @@ -78,19 +73,13 @@ def as_sql(self, compiler, connection): LabelTreeField.register_lookup(LabelQuery) -def get_hogql_query(query: str, team: Team) -> ast.SelectQuery | ast.SelectUnionQuery: - from posthog.hogql.parser import parse_select - - parsed_query = parse_select(query) - return parsed_query +def get_parents_from_model_query(model_query: str) -> set[str]: + """Get parents from a given query. - -def get_parents_from_model_query(model_query: str, team: Team): - """Get parent models from a given query. - - This corresponds to any names in the FROM clause of the query. + The parents of a query are any names in the `FROM` clause of the query. """ - hogql_query = get_hogql_query(query=model_query, team=team) + + hogql_query = parse_select(model_query) if isinstance(hogql_query, ast.SelectUnionQuery): queries = hogql_query.select_queries @@ -98,9 +87,20 @@ def get_parents_from_model_query(model_query: str, team: Team): queries = [hogql_query] parents = set() + ctes = set() while queries: query = queries.pop() + + if query.ctes is not None: + for name, cte in query.ctes.items(): + ctes.add(name) + + if isinstance(cte.expr, ast.SelectUnionQuery): + queries.extend(cte.expr.select_queries) + elif isinstance(cte.expr, ast.SelectQuery): + queries.append(cte.expr) + join = query.select_from if join is None: @@ -116,12 +116,22 @@ def get_parents_from_model_query(model_query: str, team: Team): queries.extend(join.table.select_queries) while join is not None: - parents.add(join.table.chain[0]) # type: ignore + parent_name = join.table.chain[0] # type: ignore + + if parent_name not in ctes and isinstance(parent_name, str): + parents.add(parent_name) + join = join.next_join + return parents -NodeType = typing.Literal["SavedQuery", "Table", "PostHog"] +class NodeType(enum.Enum): + SAVED_QUERY = "SavedQuery" + POSTHOG = "PostHog" + TABLE = "Table" + + NodeId = str Node = tuple[NodeId, NodeType] Edge = tuple[NodeId, NodeId] @@ -135,7 +145,7 @@ class DAG: nodes: set[Node] -INSERT_QUERY = """\ +UPDATE_PATHS_QUERY = """\ insert into posthog_datawarehousemodelpath ( id, team_id, @@ -191,120 +201,239 @@ class DAG: """ +class UnknownParentError(Exception): + """Exception raised when the parent for a model is not found.""" + + def __init__(self, parent: str, query: str): + super().__init__( + f"The parent name {parent} does not correspond to an existing PostHog table, Data Warehouse Table, or Data Warehouse Saved Query." + ) + self.query = query + + +class ModelPathAlreadyExistsError(Exception): + """Exception raised when trying to create paths for a model that already has some.""" + + def __init__(self, model_name: str): + super().__init__(f"Model {model_name} cannot be created as it already exists") + + +class ModelPathDoesNotExistError(Exception): + """Exception raised when trying to update paths for a model that doesn't exist.""" + + def __init__(self, model_name: str): + super().__init__(f"Model {model_name} doesn't exist") + + class DataWarehouseModelPathManager(models.Manager["DataWarehouseModelPath"]): - def create_from_saved_query_instance(self, saved_query: DataWarehouseSavedQuery) -> "list[DataWarehouseModelPath]": - """Create a new model path from a new `DataWarehouseSavedQuery`. + """A model manager that implements some common path operations.""" + + def create_from_saved_query(self, saved_query: DataWarehouseSavedQuery) -> "list[DataWarehouseModelPath]": + """Create one or more model paths from a new `DataWarehouseSavedQuery`. - Creating from a new `DataWarehouseSavedQuery` is straight-forward as we don't have to worry - about this model having its own children paths that need updating: We are only adding a leaf - node to all ancestor paths. + Creating one or more model paths from a new `DataWarehouseSavedQuery` is straight-forward as we + don't have to worry about this model having its own children paths that need updating: We are + only adding a leaf node to all parents' paths. We check this model indeed does not exist to + ensure that is the case. Raises: ValueError: If no paths exists for the provided `DataWarehouseSavedQuery`. """ if self.filter(team=saved_query.team, saved_query=saved_query).exists(): - raise ValueError("Model cannot be created as it already exists, use `update_from_saved_query_instance`") + raise ModelPathAlreadyExistsError(saved_query.name) - return self.create_from_saved_query( - saved_query=saved_query.query["query"], + return self.create_leaf_paths_from_query( + query=saved_query.query["query"], team=saved_query.team, created_by=saved_query.created_by, + label=saved_query.id.hex, saved_query_id=saved_query.id, ) - def create_from_saved_query( - self, saved_query: str, team: Team, saved_query_id: uuid.UUID, created_by: User | None = None + def create_leaf_paths_from_query( + self, + query: str, + team: Team, + label: str, + created_by: User | None = None, + saved_query_id: uuid.UUID | None = None, + table_id: uuid.UUID | None = None, + deleted: bool = False, ) -> "list[DataWarehouseModelPath]": + """Create all paths to a new leaf model. + + A path will be created for each parent, as extracted from the given query. + """ base_params = { "team": team, "created_by": created_by, "saved_query_id": saved_query_id, - "table_id": None, + "table_id": table_id, + "deleted": deleted, } with transaction.atomic(): - parent_paths = [] - - for parent in get_parents_from_model_query(saved_query, team): - if parent in POSTHOG_ROOT_SOURCES: - parent_model_path, _ = self.get_or_create( - path=[parent], - team=team, - defaults={"deleted": False, "table": None, "saved_query": None}, - ) - parent_paths.append(parent_model_path.path) - - else: - try: - parent_query = DataWarehouseSavedQuery.objects.filter(team=team, name=parent).get() - parent_model_paths = self.filter( - team=team, saved_query=parent_query, path__lquery=f"*.{parent_query.id.hex}" - ).all() - parent_paths.extend(parent_model_path.path for parent_model_path in parent_model_paths) - - except ObjectDoesNotExist: - parent_table = DataWarehouseTable.objects.filter(team=team, name=parent).get() - - # Treat instances of `DataWarehouseTable` as root nodes - parent_model_path, _ = self.get_or_create( - path=[parent_table.id.hex], - team=team, - defaults={"table": parent_table, "deleted": False}, - ) - parent_paths.append(parent_model_path.path) + parent_paths = self.get_or_create_query_parent_paths(query, team=team) results = self.bulk_create( [ - DataWarehouseModelPath(id=uuid7(), path=[*parent_path, saved_query_id.hex], **base_params) - for parent_path in parent_paths + DataWarehouseModelPath(id=uuid7(), path=[*model_path.path, label], **base_params) + for model_path in parent_paths ] ) return results - def update_from_saved_query_instance(self, saved_query: DataWarehouseSavedQuery) -> None: + def get_or_create_root_path_for_posthog_source( + self, posthog_source_name: str, team: Team + ) -> tuple["DataWarehouseModelPath", bool]: + """Get a root path for a PostHog source, creating it if it doesn't exist. + + PostHog sources are well-known PostHog tables. We check against the team's HogQL database + to ensure that the source exists before creating the path. + + Raises: + ValueError: If the provided `posthog_source_name` is not a PostHog table. + + Returns: + A tuple with the model path and a `bool` indicating whether it was created or not. + """ + posthog_tables = self.get_hogql_database(team).get_posthog_tables() + if posthog_source_name not in posthog_tables: + raise ValueError(f"Provided source {posthog_source_name} is not a PostHog table") + + return self.get_or_create( + path=[posthog_source_name], team=team, defaults={"saved_query": None, "deleted": False} + ) + + def get_hogql_database(self, team: Team) -> Database: + """Get the HogQL database for given team.""" + return create_hogql_database(team_id=team.pk, team_arg=team) + + def get_or_create_root_path_for_data_warehouse_table( + self, data_warehouse_table: DataWarehouseTable + ) -> tuple["DataWarehouseModelPath", bool]: + """Get a root path for a `DataWarehouseTable`, creating it if it doesn't exist. + + A `DataWarehouseTable` is loaded by us into S3 or read directly from an external data source, + like our user's S3 bucket or their PostgreSQL database. + + Either way, it is a table we can consider a root node, as it's managed by data warehouse + data import workflows. + + Returns: + A tuple with the model path and a `bool` indicating whether it was created or not. + """ + table_id = data_warehouse_table.id + return self.get_or_create( + path=[table_id.hex], + team=data_warehouse_table.team, + defaults={"saved_query": None, "table": data_warehouse_table, "deleted": False}, + ) + + def filter_all_leaf_paths(self, leaf_id: str | uuid.UUID, team: Team): + """Filter all paths to leaf node given by `leaf_id`.""" + if isinstance(leaf_id, uuid.UUID): + leaf_id = leaf_id.hex + return self.filter(team=team, path__lquery=f"*.{leaf_id}") + + def get_or_create_query_parent_paths(self, query: str, team: Team) -> list["DataWarehouseModelPath"]: + """Get a list of model paths for a query's parents, creating root nodes if they do not exist.""" + parent_paths = [] + for parent in get_parents_from_model_query(query): + try: + parent_path, _ = self.get_or_create_root_path_for_posthog_source(parent, team) + except ValueError: + pass + else: + parent_paths.append(parent_path) + continue + + try: + parent_query = DataWarehouseSavedQuery.objects.filter(team=team, name=parent).get() + except ObjectDoesNotExist: + pass + else: + parent_paths.extend( + parent_path for parent_path in self.filter_all_leaf_paths(parent_query.id.hex, team=team).all() + ) + continue + + try: + parent_table = DataWarehouseTable.objects.filter(team=team, name=parent).get() + except ObjectDoesNotExist: + pass + else: + parent_path, _ = self.get_or_create_root_path_for_data_warehouse_table(parent_table) + parent_paths.append(parent_path) + continue + + raise UnknownParentError(parent, query) + + return parent_paths + + def update_from_saved_query(self, saved_query: DataWarehouseSavedQuery) -> None: """Update model paths from an existing `DataWarehouseSavedQuery`.""" if not self.filter(team=saved_query.team, saved_query=saved_query).exists(): raise ValueError("Provided saved query contains no paths to update.") - # Update descendants - self.update_from_saved_query( - saved_query=saved_query.query["query"], + self.update_paths_from_query( + query=saved_query.query["query"], team=saved_query.team, + label=saved_query.id.hex, saved_query_id=saved_query.id, ) - def update_from_saved_query(self, saved_query: str, team: Team, saved_query_id: uuid.UUID): - parents = get_parents_from_model_query(saved_query, team) - parent_ids = [] + def update_paths_from_query( + self, + query: str, + team: Team, + label: str, + saved_query_id: uuid.UUID | None = None, + table_id: uuid.UUID | None = None, + ) -> None: + """Update all model paths from a given query. + + We parse the query to extract all its direct parents. Then, we update all the paths + that contain `label` to add an edge from parent and `label`, effectively removing the + previous parent path. + + This may lead to duplicate paths, so we have to defer constraints, until the end of + the transaction and clean them up. + """ + parents = get_parents_from_model_query(query) + posthog_tables = self.get_hogql_database(team).get_posthog_tables() + + base_params = { + "team_id": team.pk, + "saved_query_id": saved_query_id, + "table_id": table_id, + } with transaction.atomic(): with connection.cursor() as cursor: + cursor.execute("SET CONSTRAINTS ALL DEFERRED") + for parent in parents: - if parent in POSTHOG_ROOT_SOURCES: + if parent in posthog_tables: parent_id = parent else: try: parent_query = DataWarehouseSavedQuery.objects.filter(team=team, name=parent).get() - parent_id = parent_query.id.hex except ObjectDoesNotExist: - parent_table = DataWarehouseTable.objects.filter(team=team, name=parent).get() - parent_id = parent_table.id.hex - - parent_ids.append(parent_id) + try: + parent_table = DataWarehouseTable.objects.filter(team=team, name=parent).get() + except ObjectDoesNotExist: + raise UnknownParentError(parent, query) + else: + parent_id = parent_table.id.hex + else: + parent_id = parent_query.id.hex - cursor.execute( - INSERT_QUERY, params={"child": saved_query_id.hex, "parent": parent_id, "team_id": team.pk} - ) + cursor.execute(UPDATE_PATHS_QUERY, params={**{"child": label, "parent": parent_id}, **base_params}) cursor.execute(DELETE_DUPLICATE_PATHS_QUERY, params={"team_id": team.pk}) cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") - def get_paths_to_leaf_model( - self, leaf_model: DataWarehouseSavedQuery | DataWarehouseTable - ) -> "models.QuerySet[DataWarehouseModelPath]": - """Return all paths to a leaf model.""" - return self.filter(path__lquery=f"*.{leaf_model.id.hex}").all() - def get_longest_common_ancestor_path( self, leaf_models: collections.abc.Iterable[DataWarehouseSavedQuery | DataWarehouseTable] ) -> str | None: @@ -335,12 +464,12 @@ def get_dag(self, team: Team): node_id: NodeId for model_path in self.filter(team=team).select_related("saved_query", "table").all(): - if model_path.table is not None: - node_type = "Table" - elif model_path.saved_query is not None: - node_type = "SavedQuery" + if model_path.saved_query is not None: + node_type = NodeType.SAVED_QUERY + elif model_path.table is not None: + node_type = NodeType.TABLE else: - node_type = "PostHog" + node_type = NodeType.POSTHOG for index, node_id in enumerate(model_path.path): try: @@ -355,12 +484,16 @@ def get_dag(self, team: Team): class DataWarehouseModelPath(CreatedMetaFields, UpdatedMetaFields, UUIDModel, DeletedMetaFields): - """Represent a path to a model.""" + """Django model to represent paths to a data warehouse model. + + A data warehouse model is represented by a saved query, and the path to it contains all + tables and views that said query is selecting from, recursively all the way to root + PostHog tables and external data source tables. + """ class Meta: indexes = [ models.Index(fields=("team_id", "path"), name="team_id_path"), - models.Index(fields=("team_id", "table_id"), name="team_id_table_id"), models.Index(fields=("team_id", "saved_query_id"), name="team_id_saved_query_id"), pg_indexes.GistIndex("path", name="model_path_path"), ] diff --git a/posthog/warehouse/models/test/test_modeling.py b/posthog/warehouse/models/test/test_modeling.py index 5c66d3d70026dd..e5a2c2c8bc3bcf 100644 --- a/posthog/warehouse/models/test/test_modeling.py +++ b/posthog/warehouse/models/test/test_modeling.py @@ -1,6 +1,25 @@ +import pytest + from posthog.test.base import BaseTest from posthog.warehouse.models.datawarehouse_saved_query import DataWarehouseSavedQuery -from posthog.warehouse.models.modeling import DataWarehouseModelPath +from posthog.warehouse.models.modeling import DataWarehouseModelPath, NodeType, get_parents_from_model_query + + +@pytest.mark.parametrize( + "query,parents", + [ + ("select * from events, persons", {"events", "persons"}), + ("select * from some_random_view", {"some_random_view"}), + ( + "with cte as (select * from events), cte2 as (select * from cte), cte3 as (select 1) select * from cte2", + {"events"}, + ), + ("select 1", set()), + ], +) +def test_get_parents_from_model_query(query: str, parents: set[str]): + """Test parents are correctly parsed from sample queries.""" + assert parents == get_parents_from_model_query(query) class TestModelPath(BaseTest): @@ -20,7 +39,7 @@ def test_create_from_posthog_root_nodes_query(self): query={"query": query}, ) - model_paths = DataWarehouseModelPath.objects.create_from_saved_query_instance(saved_query) + model_paths = DataWarehouseModelPath.objects.create_from_saved_query(saved_query) paths = [model_path.path for model_path in model_paths] self.assertEqual(len(paths), 2) @@ -48,8 +67,8 @@ def test_create_from_existing_path(self): query={"query": "select * from my_model as my_other_model"}, ) - parent_model_paths = DataWarehouseModelPath.objects.create_from_saved_query_instance(parent_saved_query) - child_model_paths = DataWarehouseModelPath.objects.create_from_saved_query_instance(child_saved_query) + parent_model_paths = DataWarehouseModelPath.objects.create_from_saved_query(parent_saved_query) + child_model_paths = DataWarehouseModelPath.objects.create_from_saved_query(child_saved_query) parent_paths = [model_path.path for model_path in parent_model_paths] child_paths = [model_path.path for model_path in child_model_paths] @@ -88,13 +107,13 @@ def test_update_path_from_saved_query(self): query={"query": "select * from my_model_child"}, ) - DataWarehouseModelPath.objects.create_from_saved_query_instance(parent_saved_query) - DataWarehouseModelPath.objects.create_from_saved_query_instance(child_saved_query) - DataWarehouseModelPath.objects.create_from_saved_query_instance(grand_child_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(parent_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(child_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(grand_child_saved_query) child_saved_query.query = {"query": "select * from events as my_other_model"} child_saved_query.save() - DataWarehouseModelPath.objects.update_from_saved_query_instance(child_saved_query) + DataWarehouseModelPath.objects.update_from_saved_query(child_saved_query) child_refreshed_model_paths = DataWarehouseModelPath.objects.filter( team=self.team, saved_query=child_saved_query @@ -135,8 +154,8 @@ def test_get_longest_common_ancestor_path(self): name="my_model2", query={"query": query_2}, ) - DataWarehouseModelPath.objects.create_from_saved_query_instance(saved_query_1) - DataWarehouseModelPath.objects.create_from_saved_query_instance(saved_query_2) + DataWarehouseModelPath.objects.create_from_saved_query(saved_query_1) + DataWarehouseModelPath.objects.create_from_saved_query(saved_query_2) lca = DataWarehouseModelPath.objects.get_longest_common_ancestor_path([saved_query_1, saved_query_2]) self.assertEqual(lca, "events") @@ -162,8 +181,8 @@ def test_get_dag(self): query={"query": "select * from my_model as my_other_model"}, ) - DataWarehouseModelPath.objects.create_from_saved_query_instance(parent_saved_query) - DataWarehouseModelPath.objects.create_from_saved_query_instance(child_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(parent_saved_query) + DataWarehouseModelPath.objects.create_from_saved_query(child_saved_query) dag = DataWarehouseModelPath.objects.get_dag(team=self.team) @@ -172,8 +191,8 @@ def test_get_dag(self): self.assertIn(("persons", parent_saved_query.id.hex), dag.edges) self.assertEqual(len(dag.edges), 3) - self.assertIn((child_saved_query.id.hex, "SavedQuery"), dag.nodes) - self.assertIn((parent_saved_query.id.hex, "SavedQuery"), dag.nodes) - self.assertIn(("events", "PostHog"), dag.nodes) - self.assertIn(("persons", "PostHog"), dag.nodes) + self.assertIn((child_saved_query.id.hex, NodeType.SAVED_QUERY), dag.nodes) + self.assertIn((parent_saved_query.id.hex, NodeType.SAVED_QUERY), dag.nodes) + self.assertIn(("events", NodeType.POSTHOG), dag.nodes) + self.assertIn(("persons", NodeType.POSTHOG), dag.nodes) self.assertEqual(len(dag.nodes), 4)