feat(data-warehouse): Use delta format for DLT (#23584)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Tom Owers <[email protected]>
Co-authored-by: Tom Owers <[email protected]>
4 people authored Jul 15, 2024
1 parent 4c61a11 commit 2bfd439
Showing 19 changed files with 320 additions and 169 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -63,4 +63,4 @@ plugin-transpiler/dist
*-esbuild-meta.json
*-esbuild-bundle-visualization.html
.dlt
*.db
*.db
2 changes: 1 addition & 1 deletion latest_migrations.manifest
@@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name
ee: 0016_rolemembership_organization_member
otp_static: 0002_throttling
otp_totp: 0002_auto_20190420_0723
posthog: 0440_organizationinvite_private_project_access
posthog: 0441_alter_datawarehousetable_format
sessions: 0001_initial
social_django: 0010_uid_db_index
two_factor: 0007_auto_20201201_1019
93 changes: 75 additions & 18 deletions posthog/hogql/database/s3_table.py
@@ -1,9 +1,76 @@
from typing import Optional

from posthog.clickhouse.client.escape import substitute_params
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import FunctionCallTable
from posthog.hogql.escape_sql import escape_hogql_identifier


def build_function_call(
    url: str,
    format: str,
    access_key: Optional[str] = None,
    access_secret: Optional[str] = None,
    structure: Optional[str] = None,
    context: Optional[HogQLContext] = None,
) -> str:
    raw_params: dict[str, str] = {}

    def add_param(value: str, is_sensitive: bool = True) -> str:
        if context is not None:
            if is_sensitive:
                return context.add_sensitive_value(value)
            return context.add_value(value)

        param_name = f"value_{len(raw_params.items())}"
        raw_params[param_name] = value
        return f"%({param_name})s"

    def return_expr(expr: str) -> str:
        if context is not None:
            return f"{expr})"

        return f"{substitute_params(expr, raw_params)})"

    if format == "Delta":
        escaped_url = add_param(url)
        if structure:
            escaped_structure = add_param(structure, False)

        expr = f"deltaLake({escaped_url}"

        if access_key and access_secret:
            escaped_access_key = add_param(access_key)
            escaped_access_secret = add_param(access_secret)

            expr += f", {escaped_access_key}, {escaped_access_secret}"

        if structure:
            expr += f", {escaped_structure}"

        return return_expr(expr)

    escaped_url = add_param(url)
    escaped_format = add_param(format, False)
    if structure:
        escaped_structure = add_param(structure, False)

    expr = f"s3({escaped_url}"

    if access_key and access_secret:
        escaped_access_key = add_param(access_key)
        escaped_access_secret = add_param(access_secret)

        expr += f", {escaped_access_key}, {escaped_access_secret}"

    expr += f", {escaped_format}"

    if structure:
        expr += f", {escaped_structure}"

    return return_expr(expr)


class S3Table(FunctionCallTable):
    url: str
    format: str = "CSVWithNames"

@@ -15,21 +82,11 @@ def to_printed_hogql(self):
        return escape_hogql_identifier(self.name)

    def to_printed_clickhouse(self, context):
        escaped_url = context.add_sensitive_value(self.url)
        escaped_format = context.add_value(self.format)
        escaped_structure = context.add_value(self.structure)

        expr = f"s3({escaped_url}"

        if self.access_key and self.access_secret:
            escaped_access_key = context.add_sensitive_value(self.access_key)
            escaped_access_secret = context.add_sensitive_value(self.access_secret)

            expr += f", {escaped_access_key}, {escaped_access_secret}"

        expr += f", {escaped_format}"

        if self.structure:
            expr += f", {escaped_structure}"

        return f"{expr})"
        return build_function_call(
            url=self.url,
            format=self.format,
            access_key=self.access_key,
            access_secret=self.access_secret,
            structure=self.structure,
            context=context,
        )
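
As a quick illustration of the new helper (a sketch; outputs mirror the `build_function_call` tests added below), passing `context=None` substitutes parameters inline, while a `HogQLContext` turns them into `%(hogql_val_N)s` placeholders instead:

from posthog.hogql.database.s3_table import build_function_call

# Non-Delta formats render a ClickHouse s3() table function with the
# format passed as its own argument.
build_function_call("http://url.com", "Parquet", "key", "secret")
# -> "s3('http://url.com', 'key', 'secret', 'Parquet')"

# The Delta format switches to deltaLake(), which takes no format argument.
build_function_call("http://url.com", "Delta", "key", "secret")
# -> "deltaLake('http://url.com', 'key', 'secret')"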
26 changes: 24 additions & 2 deletions posthog/hogql/database/test/test_s3_table.py
@@ -1,12 +1,14 @@
from posthog.hogql.constants import MAX_SELECT_RETURNED_ROWS
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.database.s3_table import build_function_call
from posthog.hogql.parser import parse_select
from posthog.hogql.printer import print_ast
from posthog.hogql.query import create_default_modifiers_for_team
from posthog.test.base import BaseTest
from posthog.hogql.database.test.tables import create_aapl_stock_s3_table
from posthog.hogql.errors import ExposedHogQLError
from posthog.warehouse.models.table import DataWarehouseTable


class TestS3Table(BaseTest):
@@ -73,7 +75,7 @@ def test_s3_table_select_join(self):

        self.assertEqual(
            clickhouse,
            "SELECT aapl_stock.High AS High, aapl_stock.Low AS Low FROM (SELECT * FROM s3(%(hogql_val_0_sensitive)s, %(hogql_val_1)s)) AS aapl_stock JOIN (SELECT * FROM s3(%(hogql_val_3_sensitive)s, %(hogql_val_4)s)) AS aapl_stock_2 ON equals(aapl_stock.High, aapl_stock_2.High) LIMIT 10",
            "SELECT aapl_stock.High AS High, aapl_stock.Low AS Low FROM (SELECT * FROM s3(%(hogql_val_0_sensitive)s, %(hogql_val_1)s)) AS aapl_stock JOIN (SELECT * FROM s3(%(hogql_val_2_sensitive)s, %(hogql_val_3)s)) AS aapl_stock_2 ON equals(aapl_stock.High, aapl_stock_2.High) LIMIT 10",
        )

    def test_s3_table_select_join_with_alias(self):
@@ -96,7 +98,7 @@ def test_s3_table_select_join_with_alias(self):
        # Alias will completely override table name to prevent ambiguous table names that can be shared if the same table is joined from multiple times
        self.assertEqual(
            clickhouse,
            "SELECT a.High AS High, a.Low AS Low FROM (SELECT * FROM s3(%(hogql_val_0_sensitive)s, %(hogql_val_1)s)) AS a JOIN (SELECT * FROM s3(%(hogql_val_3_sensitive)s, %(hogql_val_4)s)) AS b ON equals(a.High, b.High) LIMIT 10",
            "SELECT a.High AS High, a.Low AS Low FROM (SELECT * FROM s3(%(hogql_val_0_sensitive)s, %(hogql_val_1)s)) AS a JOIN (SELECT * FROM s3(%(hogql_val_2_sensitive)s, %(hogql_val_3)s)) AS b ON equals(a.High, b.High) LIMIT 10",
        )

    def test_s3_table_select_and_non_s3_join(self):
@@ -202,3 +204,23 @@ def test_s3_table_select_in(self):
            clickhouse,
            f"SELECT events.uuid AS uuid, events.event AS event FROM events WHERE and(equals(events.team_id, {self.team.pk}), ifNull(globalIn(events.event, (SELECT aapl_stock.Date AS Date FROM s3(%(hogql_val_0_sensitive)s, %(hogql_val_1)s) AS aapl_stock)), 0)) LIMIT {MAX_SELECT_RETURNED_ROWS}",
        )

    def test_s3_build_function_call_without_context(self):
        res = build_function_call("http://url.com", DataWarehouseTable.TableFormat.Parquet, "key", "secret", None, None)
        assert res == "s3('http://url.com', 'key', 'secret', 'Parquet')"

    def test_s3_build_function_call_without_context_with_structure(self):
        res = build_function_call(
            "http://url.com", DataWarehouseTable.TableFormat.Parquet, "key", "secret", "some structure", None
        )
        assert res == "s3('http://url.com', 'key', 'secret', 'Parquet', 'some structure')"

    def test_s3_build_function_call_without_context_and_delta_format(self):
        res = build_function_call("http://url.com", DataWarehouseTable.TableFormat.Delta, "key", "secret", None, None)
        assert res == "deltaLake('http://url.com', 'key', 'secret')"

    def test_s3_build_function_call_without_context_and_delta_format_and_with_structure(self):
        res = build_function_call(
            "http://url.com", DataWarehouseTable.TableFormat.Delta, "key", "secret", "some structure", None
        )
        assert res == "deltaLake('http://url.com', 'key', 'secret', 'some structure')"
26 changes: 26 additions & 0 deletions posthog/migrations/0441_alter_datawarehousetable_format.py
@@ -0,0 +1,26 @@
# Generated by Django 4.2.11 on 2024-07-10 10:09

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("posthog", "0440_organizationinvite_private_project_access"),
    ]

    operations = [
        migrations.AlterField(
            model_name="datawarehousetable",
            name="format",
            field=models.CharField(
                choices=[
                    ("CSV", "CSV"),
                    ("CSVWithNames", "CSVWithNames"),
                    ("Parquet", "Parquet"),
                    ("JSONEachRow", "JSON"),
                    ("Delta", "Delta"),
                ],
                max_length=128,
            ),
        ),
    ]
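
For context, the `DataWarehouseTable.TableFormat` enum referenced by the new tests is what backs these `choices`; a rough sketch of the updated model field, assuming Django `TextChoices` (the actual model lives in `posthog/warehouse/models/table.py`):

from django.db import models


class DataWarehouseTable(models.Model):
    # Sketch only: shape inferred from the migration above and the tests'
    # use of TableFormat.Parquet / TableFormat.Delta.
    class TableFormat(models.TextChoices):
        CSV = "CSV", "CSV"
        CSVWithNames = "CSVWithNames", "CSVWithNames"
        Parquet = "Parquet", "Parquet"
        JSON = "JSONEachRow", "JSON"
        Delta = "Delta", "Delta"

    format = models.CharField(max_length=128, choices=TableFormat.choices)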
1 change: 1 addition & 0 deletions posthog/settings/data_warehouse.py
@@ -8,6 +8,7 @@
# for DLT
BUCKET_URL = os.getenv("BUCKET_URL", None)
AIRBYTE_BUCKET_NAME = os.getenv("AIRBYTE_BUCKET_NAME", None)
BUCKET = "test-pipeline"

HUBSPOT_APP_CLIENT_ID = os.getenv("HUBSPOT_APP_CLIENT_ID", None)
HUBSPOT_APP_CLIENT_SECRET = os.getenv("HUBSPOT_APP_CLIENT_SECRET", None)
posthog/temporal/data_imports/pipelines/hubspot/__init__.py
@@ -84,6 +84,7 @@ def hubspot(
        crm_objects,
        name=endpoint,
        write_disposition="replace",
        table_format="delta",
    )(
        object_type=OBJECT_TYPE_SINGULAR[endpoint],
        api_key=api_key,
38 changes: 27 additions & 11 deletions posthog/temporal/data_imports/pipelines/pipeline.py
@@ -11,6 +11,7 @@
from posthog.settings.base_variables import TEST
from structlog.typing import FilteringBoundLogger
from dlt.sources import DltSource
from deltalake.exceptions import DeltaError
from collections import Counter

from posthog.warehouse.data_load.validate_schema import validate_schema_and_update_table
@@ -63,12 +64,16 @@ def _get_destination(self):
                "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
                "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
                "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
                "region_name": settings.AIRBYTE_BUCKET_REGION,
                "AWS_ALLOW_HTTP": "true",
                "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
            }
        else:
            credentials = {
                "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
                "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
                "region_name": settings.AIRBYTE_BUCKET_REGION,
                "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
            }

        return dlt.destinations.filesystem(

@@ -103,11 +108,17 @@ def _run(self) -> dict[str, int]:
        while counts:
            self.logger.info(f"Running incremental (non-sql) pipeline, run {pipeline_runs}")

            try:
                pipeline.run(
                    self.source,
                    loader_file_format=self.loader_file_format,
                    refresh="drop_sources" if self.refresh_dlt and pipeline_runs == 0 else None,
                )
            except PipelineStepFailed as e:
                # Remove once DLT supports writing empty Delta files
                if isinstance(e.exception, DeltaError):
                    if e.exception.args[0] != "Generic error: No data source supplied to write command.":
                        raise
                else:
                    raise

            row_counts = pipeline.last_trace.last_normalize_info.row_counts
            # Remove any DLT tables from the counts
@@ -126,12 +137,17 @@ def _run(self) -> dict[str, int]:
            pipeline_runs = pipeline_runs + 1
        else:
            self.logger.info("Running standard pipeline")

            try:
                pipeline.run(
                    self.source,
                    loader_file_format=self.loader_file_format,
                    refresh="drop_sources" if self.refresh_dlt else None,
                )
            except PipelineStepFailed as e:
                # Remove once DLT supports writing empty Delta files
                if isinstance(e.exception, DeltaError):
                    if e.exception.args[0] != "Generic error: No data source supplied to write command.":
                        raise
                else:
                    raise
            row_counts = pipeline.last_trace.last_normalize_info.row_counts
            filtered_rows = dict(filter(lambda pair: not pair[0].startswith("_dlt"), row_counts.items()))
            counts = Counter(filtered_rows)
posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -138,6 +138,7 @@ def sql_database(
            merge_key=get_primary_key(table),
            write_disposition="merge" if incremental else "replace",
            spec=SqlDatabaseTableConfiguration,
            table_format="delta",
        )(
            engine=engine,
            table=table,
7 changes: 7 additions & 0 deletions posthog/temporal/data_imports/pipelines/stripe/__init__.py
@@ -36,6 +36,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "type": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Charge": {
            "name": "Charge",
@@ -64,6 +65,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "transfer_group": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Customer": {
            "name": "Customer",
@@ -91,6 +93,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "test_clock": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Invoice": {
            "name": "Invoice",
@@ -121,6 +124,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "subscription": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Price": {
            "name": "Price",
@@ -152,6 +156,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "type": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Product": {
            "name": "Product",
@@ -181,6 +186,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "url": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
        "Subscription": {
            "name": "Subscription",
@@ -213,6 +219,7 @@ def get_resource(name: str, is_incremental: bool) -> EndpointResource:
                    # "test_clock": "OPTIONAL_CONFIG",
                },
            },
            "table_format": "delta",
        },
    }

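The `table_format="delta"` flag threaded through the Hubspot, SQL database, and Stripe sources above is dlt's resource-level Delta support; a minimal, self-contained sketch of the same pattern (pipeline name, bucket URL, and sample data are illustrative, not from this PR):

import dlt


@dlt.resource(write_disposition="replace", table_format="delta")
def example_rows():
    # Any iterable of dicts works; dlt infers the schema.
    yield [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]


pipeline = dlt.pipeline(
    pipeline_name="example_pipeline",
    destination=dlt.destinations.filesystem(bucket_url="s3://some-bucket/prefix"),
    dataset_name="example_dataset",
)

# With table_format="delta" the filesystem destination writes a Delta table
# (via deltalake) instead of loose parquet files, which is why the pipeline
# class above passes delta-rs storage options alongside its credentials.
pipeline.run(example_rows(), loader_file_format="parquet")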