From ec15c510ff4691078e169bbfebc47556e9f4c671 Mon Sep 17 00:00:00 2001 From: Eric Duong Date: Fri, 28 Jul 2023 10:35:54 -0400 Subject: [PATCH] feat(data-warehouse): Views backend (#16573) * backend basics for a view * view parsing "working" * add tests and make view name unique * adjust tests * api tests and edge cases * typing * try comment * fix migration check * rename * block overlapping names at api level * add model validator * more naming changes and full integration test * update migration * update migration * remove repeat * ClickHouse * regex * update validator name * update migration * update constraint * casing --------- Co-authored-by: Michael Matloka --- latest_migrations.manifest | 2 +- posthog/api/__init__.py | 3 +- .../api/test/__snapshots__/test_query.ambr | 32 +++++ posthog/api/test/test_query.py | 47 ++++++++ posthog/hogql/database/database.py | 19 ++- posthog/hogql/database/models.py | 9 ++ posthog/hogql/database/test/tables.py | 40 +++++++ .../hogql/database/test/test_saved_query.py | 65 ++++++++++ posthog/hogql/printer.py | 9 +- posthog/hogql/resolver.py | 10 +- .../commands/test_migrations_are_safe.py | 2 +- .../0338_datawarehouse_saved_query.py | 60 ++++++++++ posthog/warehouse/api/saved_query.py | 70 +++++++++++ .../warehouse/api/test/test_saved_query.py | 113 ++++++++++++++++++ posthog/warehouse/models/__init__.py | 1 + .../models/datawarehouse_saved_query.py | 71 +++++++++++ posthog/warehouse/models/table.py | 21 +--- posthog/warehouse/models/util.py | 14 +++ 18 files changed, 564 insertions(+), 24 deletions(-) create mode 100644 posthog/hogql/database/test/test_saved_query.py create mode 100644 posthog/migrations/0338_datawarehouse_saved_query.py create mode 100644 posthog/warehouse/api/saved_query.py create mode 100644 posthog/warehouse/api/test/test_saved_query.py create mode 100644 posthog/warehouse/models/datawarehouse_saved_query.py create mode 100644 posthog/warehouse/models/util.py diff --git a/latest_migrations.manifest 
b/latest_migrations.manifest index fa38b719754e0..1cf39de80f10d 100644 --- a/latest_migrations.manifest +++ b/latest_migrations.manifest @@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name ee: 0015_add_verified_properties otp_static: 0002_throttling otp_totp: 0002_auto_20190420_0723 -posthog: 0337_more_session_recording_fields +posthog: 0338_datawarehouse_saved_query sessions: 0001_initial social_django: 0010_uid_db_index two_factor: 0007_auto_20201201_1019 diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index 8fcf828830ad1..1aa9a4654ec76 100644 --- a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -3,7 +3,7 @@ from posthog.api.routing import DefaultRouterPlusPlus from posthog.batch_exports import http as batch_exports from posthog.settings import EE_AVAILABLE -from posthog.warehouse.api import table +from posthog.warehouse.api import saved_query, table from . import ( activity_log, @@ -135,6 +135,7 @@ def api_not_found(request): batch_exports_router.register(r"runs", batch_exports.BatchExportRunViewSet, "runs", ["team_id", "batch_export_id"]) projects_router.register(r"warehouse_table", table.TableViewSet, "warehouse_api", ["team_id"]) +projects_router.register(r"warehouse_view", saved_query.DataWarehouseSavedQueryViewSet, "warehouse_api", ["team_id"]) # Organizations nested endpoints organizations_router = router.register(r"organizations", organization.OrganizationViewSet, "organizations") diff --git a/posthog/api/test/__snapshots__/test_query.ambr b/posthog/api/test/__snapshots__/test_query.ambr index e531281fa3a97..d85b3889be37c 100644 --- a/posthog/api/test/__snapshots__/test_query.ambr +++ b/posthog/api/test/__snapshots__/test_query.ambr @@ -169,6 +169,38 @@ allow_experimental_object_type=True ' --- +# name: TestQuery.test_full_hogql_query_view + ' + /* user_id:0 request:_snapshot_ */ + SELECT events.event AS event, + events.distinct_id AS distinct_id, + replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), 
''), 'null'), '^"|"$', '') AS key + FROM events + WHERE equals(events.team_id, 2) + ORDER BY toTimeZone(events.timestamp, 'UTC') ASC + LIMIT 100 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=True + ' +--- +# name: TestQuery.test_full_hogql_query_view.1 + ' + /* user_id:0 request:_snapshot_ */ + SELECT event_view.event, + event_view.distinct_id, + event_view.key + FROM + (SELECT events.event AS event, + events.distinct_id AS distinct_id, + replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key + FROM events + WHERE equals(events.team_id, 2) + ORDER BY toTimeZone(events.timestamp, 'UTC') ASC) AS event_view + LIMIT 100 SETTINGS readonly=2, + max_execution_time=60, + allow_experimental_object_type=True + ' +--- # name: TestQuery.test_hogql_property_filter ' /* user_id:0 request:_snapshot_ */ diff --git a/posthog/api/test/test_query.py b/posthog/api/test/test_query.py index 383e7a032a81d..222c7523db994 100644 --- a/posthog/api/test/test_query.py +++ b/posthog/api/test/test_query.py @@ -473,3 +473,50 @@ def test_invalid_query_kind(self): assert api_response.json()["code"] == "parse_error" assert "validation errors for Model" in api_response.json()["detail"] assert "type=value_error.const; given=Tomato Soup" in api_response.json()["detail"] + + @snapshot_clickhouse_queries + def test_full_hogql_query_view(self): + with freeze_time("2020-01-10 12:00:00"): + _create_person( + properties={"email": "tom@posthog.com"}, + distinct_ids=["2", "some-random-uid"], + team=self.team, + immediate=True, + ) + _create_event(team=self.team, event="sign up", distinct_id="2", properties={"key": "test_val1"}) + with freeze_time("2020-01-10 12:11:00"): + _create_event(team=self.team, event="sign out", distinct_id="2", properties={"key": "test_val2"}) + with freeze_time("2020-01-10 12:12:00"): + _create_event(team=self.team, event="sign out", distinct_id="3", properties={"key": "test_val2"}) + with 
freeze_time("2020-01-10 12:13:00"): + _create_event( + team=self.team, event="sign out", distinct_id="4", properties={"key": "test_val3", "path": "a/b/c"} + ) + flush_persons_and_events() + + with freeze_time("2020-01-10 12:14:00"): + + self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select event AS event, distinct_id as distinct_id, properties.key as key from events order by timestamp", + }, + }, + ) + query = HogQLQuery(query="select * from event_view") + api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": query.dict()}).json() + query.response = HogQLQueryResponse.parse_obj(api_response) + + self.assertEqual(query.response.results and len(query.response.results), 4) + self.assertEqual( + query.response.results, + [ + ["sign up", "2", "test_val1"], + ["sign out", "2", "test_val2"], + ["sign out", "3", "test_val2"], + ["sign out", "4", "test_val3"], + ], + ) diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index c5423f5afd976..80f9c7a3cfca1 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -52,6 +52,19 @@ class Config: raw_cohort_people: RawCohortPeople = RawCohortPeople() raw_person_overrides: RawPersonOverridesTable = RawPersonOverridesTable() + # clunky: keep table names in sync with above + _table_names: List[Table] = [ + "events", + "groups", + "person", + "person_distinct_id2", + "person_overrides", + "session_recording_events", + "session_replay_events", + "cohortpeople", + "person_static_cohort", + ] + def __init__(self, timezone: Optional[str]): super().__init__() try: @@ -77,7 +90,7 @@ def add_warehouse_tables(self, **field_definitions: Any): def create_hogql_database(team_id: int) -> Database: from posthog.models import Team - from posthog.warehouse.models import DataWarehouseTable + from posthog.warehouse.models import DataWarehouseTable, 
DataWarehouseSavedQuery team = Team.objects.get(pk=team_id) database = Database(timezone=team.timezone) @@ -89,6 +102,10 @@ def create_hogql_database(team_id: int) -> Database: tables = {} for table in DataWarehouseTable.objects.filter(team_id=team.pk).exclude(deleted=True): tables[table.name] = table.hogql_definition() + + for table in DataWarehouseSavedQuery.objects.filter(team_id=team.pk).exclude(deleted=True): + tables[table.name] = table.hogql_definition() + database.add_warehouse_tables(**tables) return database diff --git a/posthog/hogql/database/models.py b/posthog/hogql/database/models.py index ee2febeb38bcf..e326808e35f2b 100644 --- a/posthog/hogql/database/models.py +++ b/posthog/hogql/database/models.py @@ -138,3 +138,12 @@ class FunctionCallTable(Table): """ name: str + + +class SavedQuery(Table): + """ + A table that returns a subquery, e.g. my_saved_query -> (SELECT * FROM some_saved_table). The team_id guard is NOT added for the overall subquery + """ + + query: str + name: str diff --git a/posthog/hogql/database/test/tables.py b/posthog/hogql/database/test/tables.py index 7b67c4cbdde97..f675f3c8d194d 100644 --- a/posthog/hogql/database/test/tables.py +++ b/posthog/hogql/database/test/tables.py @@ -1,5 +1,6 @@ from posthog.hogql.database.models import DateDatabaseField, IntegerDatabaseField, FloatDatabaseField from posthog.hogql.database.s3_table import S3Table +from posthog.hogql.database.models import SavedQuery def create_aapl_stock_s3_table(name="aapl_stock") -> S3Table: @@ -17,3 +18,42 @@ def create_aapl_stock_s3_table(name="aapl_stock") -> S3Table: "OpenInt": IntegerDatabaseField(name="OpenInt"), }, ) + + +def create_aapl_stock_table_view() -> SavedQuery: + return SavedQuery( + name="aapl_stock_view", + query="SELECT * FROM aapl_stock", + fields={ + "Date": DateDatabaseField(name="Date"), + "Open": FloatDatabaseField(name="Open"), + "High": FloatDatabaseField(name="High"), + "Low": FloatDatabaseField(name="Low"), + }, + ) + + +def 
create_nested_aapl_stock_view() -> SavedQuery: + return SavedQuery( + name="aapl_stock_nested_view", + query="SELECT * FROM aapl_stock_view", + fields={ + "Date": DateDatabaseField(name="Date"), + "Open": FloatDatabaseField(name="Open"), + "High": FloatDatabaseField(name="High"), + "Low": FloatDatabaseField(name="Low"), + }, + ) + + +def create_aapl_stock_table_self_referencing() -> SavedQuery: + return SavedQuery( + name="aapl_stock_self", + query="SELECT * FROM aapl_stock_self", + fields={ + "Date": DateDatabaseField(name="Date"), + "Open": FloatDatabaseField(name="Open"), + "High": FloatDatabaseField(name="High"), + "Low": FloatDatabaseField(name="Low"), + }, + ) diff --git a/posthog/hogql/database/test/test_saved_query.py b/posthog/hogql/database/test/test_saved_query.py new file mode 100644 index 0000000000000..8d823db9148fc --- /dev/null +++ b/posthog/hogql/database/test/test_saved_query.py @@ -0,0 +1,65 @@ +from posthog.hogql.context import HogQLContext +from posthog.hogql.database.database import create_hogql_database +from posthog.hogql.parser import parse_select +from posthog.hogql.printer import print_ast +from posthog.test.base import BaseTest +from posthog.hogql.database.test.tables import ( + create_aapl_stock_table_view, + create_aapl_stock_s3_table, + create_nested_aapl_stock_view, + create_aapl_stock_table_self_referencing, +) + + +class TestSavedQuery(BaseTest): + maxDiff = None + + def _init_database(self): + self.database = create_hogql_database(self.team.pk) + self.database.aapl_stock_view = create_aapl_stock_table_view() + self.database.aapl_stock = create_aapl_stock_s3_table() + self.database.aapl_stock_nested_view = create_nested_aapl_stock_view() + self.database.aapl_stock_self = create_aapl_stock_table_self_referencing() + self.context = HogQLContext(team_id=self.team.pk, enable_select_queries=True, database=self.database) + + def _select(self, query: str, dialect: str = "clickhouse") -> str: + return print_ast(parse_select(query), 
self.context, dialect=dialect) + + def test_saved_query_table_select(self): + self._init_database() + + hogql = self._select(query="SELECT * FROM aapl_stock LIMIT 10", dialect="hogql") + self.assertEqual(hogql, "SELECT Date, Open, High, Low, Close, Volume, OpenInt FROM aapl_stock LIMIT 10") + + clickhouse = self._select(query="SELECT * FROM aapl_stock_view LIMIT 10", dialect="clickhouse") + + self.assertEqual( + clickhouse, + "SELECT aapl_stock_view.Date, aapl_stock_view.Open, aapl_stock_view.High, aapl_stock_view.Low, aapl_stock_view.Close, aapl_stock_view.Volume, aapl_stock_view.OpenInt FROM (WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0_sensitive)s, %(hogql_val_1)s)) SELECT aapl_stock.Date, aapl_stock.Open, aapl_stock.High, aapl_stock.Low, aapl_stock.Close, aapl_stock.Volume, aapl_stock.OpenInt FROM aapl_stock) AS aapl_stock_view LIMIT 10", + ) + + def test_nested_saved_queries(self): + self._init_database() + + hogql = self._select(query="SELECT * FROM aapl_stock LIMIT 10", dialect="hogql") + self.assertEqual(hogql, "SELECT Date, Open, High, Low, Close, Volume, OpenInt FROM aapl_stock LIMIT 10") + + clickhouse = self._select(query="SELECT * FROM aapl_stock_nested_view LIMIT 10", dialect="clickhouse") + + self.assertEqual( + clickhouse, + "SELECT aapl_stock_nested_view.Date, aapl_stock_nested_view.Open, aapl_stock_nested_view.High, aapl_stock_nested_view.Low, aapl_stock_nested_view.Close, aapl_stock_nested_view.Volume, aapl_stock_nested_view.OpenInt FROM (SELECT aapl_stock_view.Date, aapl_stock_view.Open, aapl_stock_view.High, aapl_stock_view.Low, aapl_stock_view.Close, aapl_stock_view.Volume, aapl_stock_view.OpenInt FROM (WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0_sensitive)s, %(hogql_val_1)s)) SELECT aapl_stock.Date, aapl_stock.Open, aapl_stock.High, aapl_stock.Low, aapl_stock.Close, aapl_stock.Volume, aapl_stock.OpenInt FROM aapl_stock) AS aapl_stock_view) AS aapl_stock_nested_view LIMIT 10", + ) + + def 
test_saved_query_with_alias(self): + self._init_database() + + hogql = self._select(query="SELECT * FROM aapl_stock LIMIT 10", dialect="hogql") + self.assertEqual(hogql, "SELECT Date, Open, High, Low, Close, Volume, OpenInt FROM aapl_stock LIMIT 10") + + clickhouse = self._select(query="SELECT * FROM aapl_stock_view AS some_alias LIMIT 10", dialect="clickhouse") + + self.assertEqual( + clickhouse, + "SELECT some_alias.Date, some_alias.Open, some_alias.High, some_alias.Low, some_alias.Close, some_alias.Volume, some_alias.OpenInt FROM (WITH aapl_stock AS (SELECT * FROM s3Cluster('posthog', %(hogql_val_0_sensitive)s, %(hogql_val_1)s)) SELECT aapl_stock.Date, aapl_stock.Open, aapl_stock.High, aapl_stock.Low, aapl_stock.Close, aapl_stock.Volume, aapl_stock.OpenInt FROM aapl_stock) AS some_alias LIMIT 10", + ) diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index ad82ebf266065..c1c4d718c540c 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -19,7 +19,7 @@ HOGQL_POSTHOG_FUNCTIONS, ) from posthog.hogql.context import HogQLContext -from posthog.hogql.database.models import Table, FunctionCallTable +from posthog.hogql.database.models import Table, FunctionCallTable, SavedQuery from posthog.hogql.database.database import create_hogql_database from posthog.hogql.database.s3_table import S3Table from posthog.hogql.errors import HogQLException @@ -75,7 +75,6 @@ def prepare_ast_for_printing( if dialect == "clickhouse": node = resolve_property_types(node, context) resolve_lazy_tables(node, stack, context) - # We add a team_id guard right before printing. It's not a separate step here. return node @@ -259,7 +258,11 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse: # :IMPORTANT: This assures a "team_id" where clause is present on every selected table. # Skip function call tables like numbers(), s3(), etc. 
- if self.dialect == "clickhouse" and not isinstance(table_type.table, FunctionCallTable): + if ( + self.dialect == "clickhouse" + and not isinstance(table_type.table, FunctionCallTable) + and not isinstance(table_type.table, SavedQuery) + ): extra_where = team_id_guard_for_table(node.type, self.context) if self.dialect == "clickhouse": diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 63c055abbe603..08817e8f66637 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -6,11 +6,12 @@ from posthog.hogql.ast import FieldTraverserType, ConstantType from posthog.hogql.functions import HOGQL_POSTHOG_FUNCTIONS from posthog.hogql.context import HogQLContext -from posthog.hogql.database.models import StringJSONDatabaseField, FunctionCallTable, LazyTable +from posthog.hogql.database.models import StringJSONDatabaseField, FunctionCallTable, LazyTable, SavedQuery from posthog.hogql.errors import ResolverException from posthog.hogql.functions.cohort import cohort from posthog.hogql.functions.mapping import validate_function_args from posthog.hogql.functions.sparkline import sparkline +from posthog.hogql.parser import parse_select from posthog.hogql.visitor import CloningVisitor, clone_expr from posthog.models.utils import UUIDT @@ -205,6 +206,13 @@ def visit_join_expr(self, node: ast.JoinExpr): if self.database.has_table(table_name): database_table = self.database.get_table(table_name) + + if isinstance(database_table, SavedQuery): + node.table = parse_select(str(database_table.query)) + node.alias = table_alias or database_table.name + node = self.visit(node) + return node + if isinstance(database_table, LazyTable): node_table_type = ast.LazyTableType(table=database_table) else: diff --git a/posthog/management/commands/test_migrations_are_safe.py b/posthog/management/commands/test_migrations_are_safe.py index 515cc8dee7318..c8ecbe2792ebe 100644 --- a/posthog/management/commands/test_migrations_are_safe.py +++ 
b/posthog/management/commands/test_migrations_are_safe.py @@ -67,7 +67,7 @@ def run_and_check_migration(variable): sys.exit(1) if "CONSTRAINT" in operation_sql and ( "-- existing-table-constraint-ignore" not in operation_sql - or ( + and ( table_being_altered not in tables_created_so_far or self._get_table("ALTER TABLE", operation_sql) not in new_tables ) # Ignore for brand-new tables diff --git a/posthog/migrations/0338_datawarehouse_saved_query.py b/posthog/migrations/0338_datawarehouse_saved_query.py new file mode 100644 index 0000000000000..eac5feef35a87 --- /dev/null +++ b/posthog/migrations/0338_datawarehouse_saved_query.py @@ -0,0 +1,60 @@ +# Generated by Django 3.2.19 on 2023-07-27 20:32 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import posthog.models.utils +import posthog.warehouse.models.datawarehouse_saved_query + + +class Migration(migrations.Migration): + + dependencies = [ + ("posthog", "0337_more_session_recording_fields"), + ] + + operations = [ + migrations.CreateModel( + name="DataWarehouseSavedQuery", + fields=[ + ("created_at", models.DateTimeField(auto_now_add=True)), + ("deleted", models.BooleanField(blank=True, null=True)), + ( + "id", + models.UUIDField( + default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False + ), + ), + ( + "name", + models.CharField( + max_length=128, + validators=[posthog.warehouse.models.datawarehouse_saved_query.validate_saved_query_name], + ), + ), + ( + "columns", + models.JSONField( + blank=True, + default=dict, + help_text="Dict of all columns with ClickHouse type (including Nullable())", + null=True, + ), + ), + ("query", models.JSONField(blank=True, default=dict, help_text="HogQL query", null=True)), + ( + "created_by", + models.ForeignKey( + blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL + ), + ), + ("team", 
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")), + ], + ), + migrations.AddConstraint( + model_name="datawarehousesavedquery", + constraint=models.UniqueConstraint( + fields=("team", "name"), name="posthog_datawarehouse_saved_query_unique_name" + ), + ), + ] diff --git a/posthog/warehouse/api/saved_query.py b/posthog/warehouse/api/saved_query.py new file mode 100644 index 0000000000000..f660c2c2a9133 --- /dev/null +++ b/posthog/warehouse/api/saved_query.py @@ -0,0 +1,70 @@ +from posthog.permissions import OrganizationMemberPermissions +from rest_framework.exceptions import NotAuthenticated +from rest_framework.permissions import IsAuthenticated +from rest_framework import filters, serializers, viewsets +from posthog.warehouse.models import DataWarehouseSavedQuery +from posthog.api.shared import UserBasicSerializer +from posthog.api.routing import StructuredViewSetMixin + +from posthog.models import User +from typing import Any + + +class DataWarehouseSavedQuerySerializer(serializers.ModelSerializer): + created_by = UserBasicSerializer(read_only=True) + + class Meta: + model = DataWarehouseSavedQuery + fields = ["id", "deleted", "name", "query", "created_by", "created_at", "columns"] + read_only_fields = ["id", "created_by", "created_at", "columns"] + + def create(self, validated_data): + validated_data["team_id"] = self.context["team_id"] + validated_data["created_by"] = self.context["request"].user + + view = DataWarehouseSavedQuery(**validated_data) + # The columns will be inferred from the query + try: + view.columns = view.get_columns() + except Exception as err: + raise serializers.ValidationError(str(err)) + + view.save() + return view + + def update(self, instance: Any, validated_data: Any) -> Any: + view = super().update(instance, validated_data) + + try: + view.columns = view.get_columns() + except Exception as err: + raise serializers.ValidationError(str(err)) + view.save() + return view + + +class 
DataWarehouseSavedQueryViewSet(StructuredViewSetMixin, viewsets.ModelViewSet): + """ + Create, Read, Update and Delete Data Warehouse Saved Queries. + """ + + queryset = DataWarehouseSavedQuery.objects.all() + serializer_class = DataWarehouseSavedQuerySerializer + permission_classes = [IsAuthenticated, OrganizationMemberPermissions] + filter_backends = [filters.SearchFilter] + search_fields = ["name"] + ordering = "-created_at" + + def get_queryset(self): + if not isinstance(self.request.user, User) or self.request.user.current_team is None: + raise NotAuthenticated() + + if self.action == "list": + return ( + self.queryset.filter(team_id=self.team_id) + .exclude(deleted=True) + .prefetch_related("created_by") + .order_by(self.ordering) + ) + + return self.queryset.filter(team_id=self.team_id).prefetch_related("created_by").order_by(self.ordering) diff --git a/posthog/warehouse/api/test/test_saved_query.py b/posthog/warehouse/api/test/test_saved_query.py new file mode 100644 index 0000000000000..bc14c97988493 --- /dev/null +++ b/posthog/warehouse/api/test/test_saved_query.py @@ -0,0 +1,113 @@ +from posthog.test.base import ( + APIBaseTest, +) + + +class TestSavedQuery(APIBaseTest): + def test_create(self): + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select event from events LIMIT 100", + }, + }, + ) + self.assertEqual(response.status_code, 201, response.content) + saved_query = response.json() + self.assertEqual(saved_query["name"], "event_view") + self.assertEqual(saved_query["columns"], {"event": "String"}) + + def test_create_name_overlap_error(self): + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "events", + "query": { + "kind": "HogQLQuery", + "query": f"select event from events LIMIT 100", + }, + }, + ) + self.assertEqual(response.status_code, 400, response.content) + + def 
test_saved_query_doesnt_exist(self): + saved_query_1_response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select * from event_view LIMIT 100", + }, + }, + ) + self.assertEqual(saved_query_1_response.status_code, 400, saved_query_1_response.content) + + def test_view_updated(self): + response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select event from events LIMIT 100", + }, + }, + ) + self.assertEqual(response.status_code, 201, response.content) + saved_query_1_response = response.json() + saved_query_1_response = self.client.patch( + f"/api/projects/{self.team.id}/warehouse_view/" + saved_query_1_response["id"], + { + "query": { + "kind": "HogQLQuery", + "query": f"select distinct_id from events LIMIT 100", + }, + }, + ) + + self.assertEqual(saved_query_1_response.status_code, 200, saved_query_1_response.content) + view_1 = saved_query_1_response.json() + self.assertEqual(view_1["name"], "event_view") + self.assertEqual(view_1["columns"], {"distinct_id": "String"}) + + def test_circular_view(self): + saved_query_1_response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select * from events LIMIT 100", + }, + }, + ) + self.assertEqual(saved_query_1_response.status_code, 201, saved_query_1_response.content) + saved_query_1 = saved_query_1_response.json() + + saved_view_2_response = self.client.post( + f"/api/projects/{self.team.id}/warehouse_view/", + { + "name": "outer_event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select event from event_view LIMIT 100", + }, + }, + ) + self.assertEqual(saved_view_2_response.status_code, 201, saved_view_2_response.content) + + saved_view_1_response = self.client.patch( + 
f"/api/projects/{self.team.id}/warehouse_view/" + saved_query_1["id"], + { + "name": "event_view", + "query": { + "kind": "HogQLQuery", + "query": f"select * from outer_event_view LIMIT 100", + }, + }, + ) + self.assertEqual(saved_view_1_response.status_code, 400, saved_view_1_response.content) diff --git a/posthog/warehouse/models/__init__.py b/posthog/warehouse/models/__init__.py index d879758fea39c..c37373d9f24af 100644 --- a/posthog/warehouse/models/__init__.py +++ b/posthog/warehouse/models/__init__.py @@ -1,2 +1,3 @@ from .table import * from .credential import * +from .datawarehouse_saved_query import * diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py new file mode 100644 index 0000000000000..823a608253716 --- /dev/null +++ b/posthog/warehouse/models/datawarehouse_saved_query.py @@ -0,0 +1,71 @@ +from posthog.models.utils import UUIDModel, CreatedMetaFields, DeletedMetaFields +from django.db import models +from posthog.models.team import Team + +from posthog.hogql.database.models import SavedQuery +from posthog.hogql.database.database import Database +from typing import Dict +import re +from django.core.exceptions import ValidationError +from posthog.warehouse.models.util import remove_named_tuples + + +def validate_saved_query_name(value): + if not re.match(r"^[A-Za-z_$][A-Za-z0-9_$]*$", value): + raise ValidationError( + f"{value} is not a valid view name. View names can only contain letters, numbers, '_', or '$' ", + params={"value": value}, + ) + + if value in Database._table_names: + raise ValidationError( + f"{value} is not a valid view name. 
View names cannot overlap with PostHog table names.", + params={"value": value}, + ) + + +class DataWarehouseSavedQuery(CreatedMetaFields, UUIDModel, DeletedMetaFields): + name: models.CharField = models.CharField(max_length=128, validators=[validate_saved_query_name]) + team: models.ForeignKey = models.ForeignKey(Team, on_delete=models.CASCADE) + columns: models.JSONField = models.JSONField( + default=dict, null=True, blank=True, help_text="Dict of all columns with ClickHouse type (including Nullable())" + ) + query: models.JSONField = models.JSONField(default=dict, null=True, blank=True, help_text="HogQL query") + + class Meta: + constraints = [ + models.UniqueConstraint(fields=["team", "name"], name="posthog_datawarehouse_saved_query_unique_name") + ] + + def get_columns(self) -> Dict[str, str]: + from posthog.api.query import process_query + + # TODO: catch and raise error + response = process_query(self.team, self.query) + types = response.get("types", {}) + return dict(types) + + def hogql_definition(self) -> SavedQuery: + from posthog.warehouse.models.table import CLICKHOUSE_HOGQL_MAPPING + + if not self.columns: + raise Exception("Columns must be fetched and saved to use in HogQL.") + + fields = {} + for column, type in self.columns.items(): + if type.startswith("Nullable("): + type = type.replace("Nullable(", "")[:-1] + + # TODO: remove when addressed https://github.com/ClickHouse/ClickHouse/issues/37594 + if type.startswith("Array("): + type = remove_named_tuples(type) + + type = type.partition("(")[0] + type = CLICKHOUSE_HOGQL_MAPPING[type] + fields[column] = type(name=column) + + return SavedQuery( + name=self.name, + query=self.query["query"], + fields=fields, + ) diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py index 02fa74d9997af..10e61444e8250 100644 --- a/posthog/warehouse/models/table.py +++ b/posthog/warehouse/models/table.py @@ -13,9 +13,10 @@ StringArrayDatabaseField, ) from posthog.hogql.database.s3_table 
import S3Table -import re +from posthog.warehouse.models.util import remove_named_tuples -ClickhouseHogqlMapping = { +CLICKHOUSE_HOGQL_MAPPING = { + "UUID": StringDatabaseField, "String": StringDatabaseField, "DateTime64": DateTimeDatabaseField, "DateTime32": DateTimeDatabaseField, @@ -93,11 +94,11 @@ def hogql_definition(self) -> S3Table: # TODO: remove when addressed https://github.com/ClickHouse/ClickHouse/issues/37594 if type.startswith("Array("): - type = self.remove_named_tuples(type) + type = remove_named_tuples(type) structure.append(f"{column} {type}") type = type.partition("(")[0] - type = ClickhouseHogqlMapping[type] + type = CLICKHOUSE_HOGQL_MAPPING[type] fields[column] = type(name=column) return S3Table( @@ -110,18 +111,6 @@ def hogql_definition(self) -> S3Table: structure=", ".join(structure), ) - def remove_named_tuples(self, type): - """Remove named tuples from query""" - tokenified_type = re.split(r"(\W)", type) - filtered_tokens = [ - token - for token in tokenified_type - if token == "Nullable" - or (len(token) == 1 and not token.isalnum()) - or token in ClickhouseHogqlMapping.keys() - ] - return "".join(filtered_tokens) - def _safe_expose_ch_error(self, err): err = wrap_query_error(err) for key, value in ExtractErrors.items(): diff --git a/posthog/warehouse/models/util.py b/posthog/warehouse/models/util.py new file mode 100644 index 0000000000000..d33cb47b3cbf5 --- /dev/null +++ b/posthog/warehouse/models/util.py @@ -0,0 +1,14 @@ +import re + + +def remove_named_tuples(type): + """Reduce a ClickHouse type string to `Nullable`, single punctuation characters and CLICKHOUSE_HOGQL_MAPPING type names, dropping every other token (e.g. named-tuple field names)""" + from posthog.warehouse.models.table import CLICKHOUSE_HOGQL_MAPPING + + tokenified_type = re.split(r"(\W)", type) + filtered_tokens = [ + token + for token in tokenified_type + if token == "Nullable" or (len(token) == 1 and not token.isalnum()) or token in CLICKHOUSE_HOGQL_MAPPING.keys() + ] + return "".join(filtered_tokens)