fix: Allow updating batch export with HogQL query (#20114)
* fix: Allow updating batch export with HogQL query

* refactor: Save hogql query in schema

* Update query snapshots

* Update query snapshots

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
tomasfarias and github-actions[bot] authored Feb 5, 2024
1 parent 5e7a18a commit c6cf384
Showing 3 changed files with 153 additions and 44 deletions.
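For context, the sketch below shows the kind of request this fix enables: PATCHing only the hogql_query of an existing batch export. It is a hypothetical client call — the host, the project-scoped endpoint path, the bearer-token auth, and every placeholder value are assumptions inferred from the tests in this commit, not a documented client API.

```python
# Hypothetical sketch of the update this commit enables. The endpoint path,
# auth header, and all placeholder values are assumptions inferred from the
# tests below; they are not taken from a documented PostHog client.
import requests

POSTHOG_HOST = "https://app.posthog.com"  # placeholder host
TEAM_ID = 1  # placeholder project/team id
BATCH_EXPORT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder id
API_KEY = "phx_example"  # placeholder personal API key

response = requests.patch(
    f"{POSTHOG_HOST}/api/projects/{TEAM_ID}/batch_exports/{BATCH_EXPORT_ID}",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"hogql_query": "select toString(uuid) as uuid, event from events"},
)
response.raise_for_status()

# After this commit, the serialized schema echoes back the normalized query.
print(response.json()["schema"]["hogql_query"])
```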
3 changes: 3 additions & 0 deletions posthog/api/test/batch_exports/test_create.py
@@ -260,6 +260,8 @@ def test_create_batch_export_with_custom_schema(client: HttpClient):
assert response.status_code == status.HTTP_201_CREATED, response.json()

data = response.json()
expected_hogql_query = " ".join(TEST_HOGQL_QUERY.split()) # Don't care about whitespace
assert data["schema"]["hogql_query"] == expected_hogql_query

codec = EncryptionCodec(settings=settings)
schedule = describe_schedule(temporal, data["id"])
@@ -288,6 +290,7 @@ def test_create_batch_export_with_custom_schema(client: HttpClient):
"hogql_val_0": "$browser",
"hogql_val_1": "custom",
},
"hogql_query": expected_hogql_query,
}

assert batch_export.schema == expected_schema
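The new assertion above compares against a whitespace-collapsed copy of the query. A minimal sketch of that normalization — TEST_HOGQL_QUERY here is a stand-in value, since the real constant is defined elsewhere in the test module:

```python
# str.split() with no separator splits on any run of whitespace, so re-joining
# with single spaces collapses newlines, tabs, and repeated spaces alike.
# TEST_HOGQL_QUERY below is a stand-in, not the constant from the tests.
TEST_HOGQL_QUERY = """
SELECT toString(uuid)   AS uuid,
       event
FROM   events
"""

expected_hogql_query = " ".join(TEST_HOGQL_QUERY.split())
assert expected_hogql_query == "SELECT toString(uuid) AS uuid, event FROM events"
```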
92 changes: 92 additions & 0 deletions posthog/api/test/batch_exports/test_update.py
@@ -241,3 +241,95 @@ def test_can_patch_config_with_invalid_old_values(client: HttpClient, interval):
args = json.loads(decoded_payload[0].data)
assert args["bucket_name"] == "my-new-production-s3-bucket"
assert args.get("invalid_key", None) is None


def test_can_patch_hogql_query(client: HttpClient):
"""Test we can patch a schema with a HogQL query."""
temporal = sync_connect()

destination_data = {
"type": "S3",
"config": {
"bucket_name": "my-production-s3-bucket",
"region": "us-east-1",
"prefix": "posthog-events/",
"aws_access_key_id": "abc123",
"aws_secret_access_key": "secret",
},
}

batch_export_data = {
"name": "my-production-s3-bucket-destination",
"destination": destination_data,
"interval": "hour",
}

organization = create_organization("Test Org")
team = create_team(organization)
user = create_user("[email protected]", "Test User", organization)
client.force_login(user)

with start_test_worker(temporal):
batch_export = create_batch_export_ok(
client,
team.pk,
batch_export_data,
)
old_schedule = describe_schedule(temporal, batch_export["id"])

new_batch_export_data = {
"name": "my-production-s3-bucket-destination",
"hogql_query": "select toString(uuid) as uuid, 'test' as test, toInt(1+1) as n from events",
}

response = patch_batch_export(client, team.pk, batch_export["id"], new_batch_export_data)
assert response.status_code == status.HTTP_200_OK, response.json()

batch_export = get_batch_export_ok(client, team.pk, batch_export["id"])
assert batch_export["interval"] == "hour"
assert batch_export["destination"]["config"]["bucket_name"] == "my-production-s3-bucket"
assert batch_export["schema"] == {
"fields": [
{
"alias": "uuid",
"expression": "toString(events.uuid)",
},
{
"alias": "test",
"expression": "%(hogql_val_0)s",
},
{
"alias": "n",
"expression": "toInt64OrNull(plus(1, 1))",
},
],
"values": {"hogql_val_0": "test"},
"hogql_query": "SELECT toString(uuid) AS uuid, 'test' AS test, toInt(plus(1, 1)) AS n FROM events",
}

# validate the underlying temporal schedule has been updated
codec = EncryptionCodec(settings=settings)
new_schedule = describe_schedule(temporal, batch_export["id"])
assert old_schedule.schedule.spec.intervals[0].every == new_schedule.schedule.spec.intervals[0].every
decoded_payload = async_to_sync(codec.decode)(new_schedule.schedule.action.args)
args = json.loads(decoded_payload[0].data)
assert args["bucket_name"] == "my-production-s3-bucket"
assert args["interval"] == "hour"
assert args["batch_export_schema"] == {
"fields": [
{
"alias": "uuid",
"expression": "toString(events.uuid)",
},
{
"alias": "test",
"expression": "%(hogql_val_0)s",
},
{
"alias": "n",
"expression": "toInt64OrNull(plus(1, 1))",
},
],
"values": {"hogql_val_0": "test"},
"hogql_query": "SELECT toString(uuid) AS uuid, 'test' AS test, toInt(plus(1, 1)) AS n FROM events",
}
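The normalized strings asserted in this test come from round-tripping the submitted query through HogQL's parser and printer. The sketch below reconstructs that pipeline with the same helpers the serializer in http.py calls; the exact module paths and the team/Django setup are assumptions, so treat it as illustrative rather than standalone-runnable.

```python
# Illustrative sketch, assuming PostHog's posthog.hogql package layout and a
# configured Django environment with a valid team; not runnable in isolation.
from posthog.hogql.context import HogQLContext
from posthog.hogql.parser import parse_select
from posthog.hogql.printer import prepare_ast_for_printing, print_prepared_ast

context = HogQLContext(team_id=1, enable_select_queries=True, limit_top_select=False)

# Parse the raw query exactly as the client submitted it...
query = parse_select(
    "select toString(uuid) as uuid, 'test' as test, toInt(1+1) as n from events"
)

# ...then resolve it and print it back in the "hogql" dialect. Printing
# canonicalizes the text: keywords are uppercased (select -> SELECT,
# as -> AS) and operators become function calls (1+1 -> plus(1, 1)).
prepared = prepare_ast_for_printing(query, context=context, dialect="hogql")
print(print_prepared_ast(prepared, context=context, dialect="hogql"))
# SELECT toString(uuid) AS uuid, 'test' AS test, toInt(plus(1, 1)) AS n FROM events
```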
102 changes: 58 additions & 44 deletions posthog/batch_exports/http.py
@@ -26,6 +26,7 @@
)
from posthog.batch_exports.service import (
BatchExportIdError,
BatchExportSchema,
BatchExportServiceError,
BatchExportServiceRPCError,
BatchExportServiceScheduleNotFound,
@@ -165,7 +166,19 @@ def to_internal_value(self, data: str) -> ast.SelectQuery | ast.SelectUnionQuery
except Exception:
raise serializers.ValidationError("Failed to parse query")

return parsed_query
try:
prepared_select_query: ast.SelectQuery = cast(
ast.SelectQuery,
prepare_ast_for_printing(
parsed_query,
context=HogQLContext(team_id=self.context["team_id"], enable_select_queries=True),
dialect="hogql",
),
)
except errors.ResolverException:
raise serializers.ValidationError(f"Invalid HogQL query")

return prepared_select_query


@@ -176,6 +189,7 @@ class BatchExportsField(TypedDict):
class BatchExportsSchema(TypedDict):
fields: list[BatchExportsField]
values: dict[str, str]
hogql_query: str


class BatchExportSerializer(serializers.ModelSerializer):
@@ -202,14 +216,9 @@ class Meta:
"end_at",
"latest_runs",
"hogql_query",
"schema",
]
read_only_fields = [
"id",
"team_id",
"created_at",
"last_updated_at",
"latest_runs",
]
read_only_fields = ["id", "team_id", "created_at", "last_updated_at", "latest_runs", "schema"]

def create(self, validated_data: dict) -> BatchExport:
"""Create a BatchExport."""
@@ -233,43 +242,9 @@ def create(self, validated_data: dict) -> BatchExport:
):
raise PermissionDenied("Higher frequency exports are not enabled for this team.")

hogql_query = None
if hogql_query := validated_data.pop("hogql_query", None):
context = HogQLContext(
team_id=team_id,
enable_select_queries=True,
)

try:
prepared_select_query: ast.SelectQuery = cast(
ast.SelectQuery,
prepare_ast_for_printing(hogql_query, context=context, dialect="clickhouse"),
)
except errors.ResolverException:
raise serializers.ValidationError(f"Invalid HogQL query")

batch_export_schema: BatchExportsSchema = {
"fields": [],
"values": {},
}
for field in prepared_select_query.select:
expression = print_prepared_ast(
field.expr, # type: ignore
context=context,
dialect="clickhouse",
)

if isinstance(field, ast.Alias):
alias = field.alias
else:
alias = expression

batch_export_field: BatchExportsField = {
"expression": expression,
"alias": alias,
}
batch_export_schema["fields"].append(batch_export_field)

batch_export_schema["values"] = context.values
batch_export_schema = self.serialize_hogql_query_to_batch_export_schema(hogql_query)
validated_data["schema"] = batch_export_schema

destination = BatchExportDestination(**destination_data)
@@ -282,6 +257,41 @@

return batch_export

def serialize_hogql_query_to_batch_export_schema(self, hogql_query: ast.SelectQuery) -> BatchExportSchema:
"""Return a batch export schema from a HogQL query ast."""
context = HogQLContext(
team_id=self.context["team_id"],
enable_select_queries=True,
limit_top_select=False,
)

batch_export_schema: BatchExportsSchema = {
"fields": [],
"values": {},
"hogql_query": print_prepared_ast(hogql_query, context=context, dialect="hogql"),
}
for field in hogql_query.select:
expression = print_prepared_ast(
field.expr, # type: ignore
context=context,
dialect="clickhouse",
)

if isinstance(field, ast.Alias):
alias = field.alias
else:
alias = expression

batch_export_field: BatchExportsField = {
"expression": expression,
"alias": alias,
}
batch_export_schema["fields"].append(batch_export_field)

batch_export_schema["values"] = context.values

return batch_export_schema

def validate_hogql_query(self, hogql_query: ast.SelectQuery | ast.SelectUnionQuery) -> ast.SelectQuery:
"""Validate a HogQLQuery being used for batch exports.
@@ -322,6 +332,10 @@ def update(self, batch_export: BatchExport, validated_data: dict) -> BatchExport
**destination_data.get("config", {}),
}

if hogql_query := validated_data.pop("hogql_query", None):
batch_export_schema = self.serialize_hogql_query_to_batch_export_schema(hogql_query)
validated_data["schema"] = batch_export_schema

batch_export.destination.save()
batch_export = super().update(batch_export, validated_data)

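Net effect of the http.py changes: whether a hogql_query arrives via create() or, after this fix, via update(), it is expanded by serialize_hogql_query_to_batch_export_schema into the same schema dict. For the test query above, the stored value (copied from test_update.py's assertions) has this shape:

```python
# The schema dict stored for the test query, as asserted in test_update.py
# above. Field expressions are printed in the ClickHouse dialect, string
# literals become bound parameters in "values", and the normalized HogQL
# text is kept verbatim under "hogql_query".
schema = {
    "fields": [
        {"alias": "uuid", "expression": "toString(events.uuid)"},
        {"alias": "test", "expression": "%(hogql_val_0)s"},
        {"alias": "n", "expression": "toInt64OrNull(plus(1, 1))"},
    ],
    "values": {"hogql_val_0": "test"},
    "hogql_query": "SELECT toString(uuid) AS uuid, 'test' AS test, "
    "toInt(plus(1, 1)) AS n FROM events",
}
```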
