fix(batch-exports): Use prefix in snowflake internal stage (#25847)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
tomasfarias and github-actions[bot] authored Oct 28, 2024
1 parent a0612b9 commit 2b389b3
Showing 2 changed files with 134 additions and 10 deletions.
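The gist of the change: every internal table-stage operation (REMOVE, PUT, COPY INTO) is now scoped to a per-run prefix derived from `data_interval_end`, so a retried batch can clear its own leftover files without touching anything else staged for the table. A minimal standalone sketch of that pattern, assuming the `snowflake-connector-python` client with placeholder credentials and a hypothetical table, prefix, and local file:

```python
import snowflake.connector

# Placeholder credentials; in the real activity these come from SnowflakeInsertInputs.
conn = snowflake.connector.connect(
    user="user",
    password="password",
    account="account",
    warehouse="warehouse",
    database="database",
    schema="schema",
)

table_name = "events_export"        # hypothetical export table
prefix = "2024-10-28_14-30-00"      # data_interval_end formatted as in this commit
local_file = "/tmp/batch_0.jsonl"   # hypothetical temporary export file

with conn.cursor() as cursor:
    # Clear anything a previous attempt of this same batch staged under the prefix.
    cursor.execute(f"""REMOVE '@%"{table_name}"/{prefix}'""")

    # Stage the file under the per-run prefix instead of the bare table stage.
    cursor.execute(f"""PUT file://{local_file} '@%"{table_name}"/{prefix}'""")

    # Load only this run's files; PURGE drops them from the stage once copied.
    cursor.execute(
        f"""
        COPY INTO "{table_name}"
        FROM '@%"{table_name}"/{prefix}'
        FILE_FORMAT = (TYPE = 'JSON')
        MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
        PURGE = TRUE
        """
    )
```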
35 changes: 29 additions & 6 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -261,6 +261,15 @@ async def execute_async_query(
results = cursor.fetchall()
return results

async def aremove_internal_stage_files(self, table_name: str, table_stage_prefix: str) -> None:
"""Asynchronously remove files from internal table stage.
Arguments:
table_name: The name of the table whose internal stage to clear.
table_stage_prefix: Prefix to path of internal stage files.
"""
await self.execute_async_query(f"""REMOVE '@%"{table_name}"/{table_stage_prefix}'""")

async def acreate_table(self, table_name: str, fields: list[SnowflakeField]) -> None:
"""Asynchronously create the table if it doesn't exist.
@@ -297,24 +306,30 @@ async def adelete_table(
async def managed_table(
self,
table_name: str,
table_stage_prefix: str,
fields: list[SnowflakeField],
not_found_ok: bool = True,
delete: bool = True,
create: bool = True,
) -> collections.abc.AsyncGenerator[str, None]:
"""Manage a table in Snowflake by ensure it exists while in context."""
if create:
if create is True:
await self.acreate_table(table_name, fields)
else:
await self.aremove_internal_stage_files(table_name, table_stage_prefix)

try:
yield table_name
finally:
if delete is True:
await self.adelete_table(table_name, not_found_ok)
else:
await self.aremove_internal_stage_files(table_name, table_stage_prefix)

async def put_file_to_snowflake_table(
self,
file: BatchExportTemporaryFile,
table_stage_prefix: str,
table_name: str,
file_no: int,
):
Expand Down Expand Up @@ -342,7 +357,9 @@ async def put_file_to_snowflake_table(
# We comply with the file-like interface of io.IOBase.
# So we ask mypy to be nice with us.
reader = io.BufferedReader(file) # type: ignore
query = f'PUT file://{file.name}_{file_no}.jsonl @%"{table_name}"'
query = f"""
PUT file://{file.name}_{file_no}.jsonl '@%"{table_name}"/{table_stage_prefix}'
"""

with self.connection.cursor() as cursor:
cursor = self.connection.cursor()
@@ -366,6 +383,7 @@ async def put_file_to_snowflake_table(
async def copy_loaded_files_to_snowflake_table(
self,
table_name: str,
table_stage_prefix: str,
) -> None:
"""Execute a COPY query in Snowflake to load any files PUT into the table.
@@ -377,6 +395,7 @@
"""
query = f"""
COPY INTO "{table_name}"
FROM '@%"{table_name}"/{table_stage_prefix}'
FILE_FORMAT = (TYPE = 'JSON')
MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
PURGE = TRUE
@@ -631,9 +650,11 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor

async with SnowflakeClient.from_inputs(inputs).connect() as snow_client:
async with (
snow_client.managed_table(inputs.table_name, table_fields, delete=False) as snow_table,
snow_client.managed_table(
stagle_table_name, table_fields, create=requires_merge, delete=requires_merge
inputs.table_name, data_interval_end_str, table_fields, delete=False
) as snow_table,
snow_client.managed_table(
stagle_table_name, data_interval_end_str, table_fields, create=requires_merge, delete=requires_merge
) as snow_stage_table,
):
record_columns = [field[0] for field in table_fields]
@@ -660,7 +681,9 @@ async def flush_to_snowflake(

table = snow_stage_table if requires_merge else snow_table

await snow_client.put_file_to_snowflake_table(local_results_file, table, flush_counter)
await snow_client.put_file_to_snowflake_table(
local_results_file, data_interval_end_str, table, flush_counter
)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

@@ -678,7 +701,7 @@ async def flush_to_snowflake(
await writer.write_record_batch(record_batch)

await snow_client.copy_loaded_files_to_snowflake_table(
snow_stage_table if requires_merge else snow_table
snow_stage_table if requires_merge else snow_table, data_interval_end_str
)
if requires_merge:
merge_key = (
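The second changed file below is the test suite for the Snowflake batch export. Before it, a simplified, self-contained sketch of the `managed_table` control flow added above: on entry it either creates the table or clears this run's prefixed stage files, and on exit it either drops the table or clears the prefix again. The `FakeSnowflakeClient` and its reduced method signatures are illustrative stand-ins, not the real `SnowflakeClient` API:

```python
import asyncio
import collections.abc
import contextlib


class FakeSnowflakeClient:
    """Stand-in client that records the operations it would run against Snowflake."""

    def __init__(self) -> None:
        self.operations: list[str] = []

    async def acreate_table(self, table_name: str) -> None:
        self.operations.append(f"CREATE TABLE {table_name}")

    async def adelete_table(self, table_name: str, not_found_ok: bool) -> None:
        self.operations.append(f"DROP TABLE {table_name}")

    async def aremove_internal_stage_files(self, table_name: str, table_stage_prefix: str) -> None:
        self.operations.append(f"REMOVE @%{table_name}/{table_stage_prefix}")

    @contextlib.asynccontextmanager
    async def managed_table(
        self,
        table_name: str,
        table_stage_prefix: str,
        not_found_ok: bool = True,
        delete: bool = True,
        create: bool = True,
    ) -> collections.abc.AsyncGenerator[str, None]:
        # On entry: create the table, or clear this run's stage files if the table is reused.
        if create:
            await self.acreate_table(table_name)
        else:
            await self.aremove_internal_stage_files(table_name, table_stage_prefix)
        try:
            yield table_name
        finally:
            # On exit: drop temporary tables, or clear the prefix on persistent ones.
            if delete:
                await self.adelete_table(table_name, not_found_ok)
            else:
                await self.aremove_internal_stage_files(table_name, table_stage_prefix)


async def main() -> None:
    client = FakeSnowflakeClient()
    # A persistent destination table: no create, no drop, just stage cleanup on both ends.
    async with client.managed_table("events", "2024-10-28_14-30-00", delete=False, create=False):
        pass
    print(client.operations)


asyncio.run(main())
```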
@@ -5,6 +5,7 @@
import operator
import os
import re
import tempfile
import unittest.mock
import uuid
from collections import deque
@@ -394,6 +395,7 @@ async def test_snowflake_export_workflow_exports_events(
count.
"""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
data_interval_end_str = data_interval_end.strftime("%Y-%m-%d_%H-%M-%S")
data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta

await generate_test_events_in_clickhouse(
@@ -451,12 +453,12 @@ async def test_snowflake_export_workflow_exports_events(
execute_calls = []
for cursor in fake_conn._cursors:
for call in cursor._execute_calls:
execute_calls.append(call["query"])
execute_calls.append(call["query"].strip())

execute_async_calls = []
for cursor in fake_conn._cursors:
for call in cursor._execute_async_calls:
execute_async_calls.append(call["query"])
execute_async_calls.append(call["query"].strip())

assert execute_async_calls[0:3] == [
f'USE DATABASE "{database}"',
@@ -467,8 +469,9 @@
assert all(query.startswith("PUT") for query in execute_calls[0:9])
assert all(f"_{n}.jsonl" in query for n, query in enumerate(execute_calls[0:9]))

assert execute_async_calls[3].strip().startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"')
assert execute_async_calls[4].strip().startswith(f'COPY INTO "{table_name}"')
assert execute_async_calls[3].startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"')
assert execute_async_calls[4].startswith(f"""REMOVE '@%"{table_name}"/{data_interval_end_str}'""")
assert execute_async_calls[5].startswith(f'COPY INTO "{table_name}"')

runs = await afetch_batch_export_runs(batch_export_id=snowflake_batch_export.id)
assert len(runs) == 1
@@ -1141,6 +1144,104 @@ async def test_insert_into_snowflake_activity_merges_data_in_follow_up_runs(
)


@pytest.fixture
def garbage_jsonl_file():
"""Manage a JSON file with garbage data."""
with tempfile.NamedTemporaryFile("w+b", suffix=".jsonl", prefix="garbage_") as garbage_jsonl_file:
garbage_jsonl_file.write(b'{"team_id": totally not an integer}\n')
garbage_jsonl_file.seek(0)

yield garbage_jsonl_file.name


@SKIP_IF_MISSING_REQUIRED_ENV_VARS
async def test_insert_into_snowflake_activity_removes_internal_stage_files(
clickhouse_client,
activity_environment,
snowflake_cursor,
snowflake_config,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
garbage_jsonl_file,
):
"""Test that the `insert_into_snowflake_activity` removes internal stage files.
This test requires some setup steps:
1. We do a first run of the activity to create the export table. Since we
haven't added any garbage, this should work normally.
2. Truncate the table to avoid duplicate data once we re-run the activity.
3. PUT a file with garbage data into the table internal stage.
When we run the activity a second time, it should first remove the garbage
file so the COPY does not fail. After this second execution is done, besides
checking that this second run worked and exported data, we also check that no
files remain in the table's internal stage.
"""
model = BatchExportModel(name="events", schema=None)

table_name = f"test_insert_activity_table_remove_{ateam.pk}"

insert_inputs = SnowflakeInsertInputs(
team_id=ateam.pk,
table_name=table_name,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
batch_export_model=model,
**snowflake_config,
)

await activity_environment.run(insert_into_snowflake_activity, insert_inputs)

await assert_clickhouse_records_in_snowflake(
snowflake_cursor=snowflake_cursor,
clickhouse_client=clickhouse_client,
table_name=table_name,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
batch_export_model=model,
sort_key="event",
)

snowflake_cursor.execute(f'TRUNCATE TABLE "{table_name}"')

data_interval_end_str = data_interval_end.strftime("%Y-%m-%d_%H-%M-%S")

put_query = f"""
PUT file://{garbage_jsonl_file} '@%"{table_name}"/{data_interval_end_str}'
"""
snowflake_cursor.execute(put_query)

list_query = f"""
LIST '@%"{table_name}"'
"""
snowflake_cursor.execute(list_query)
rows = snowflake_cursor.fetchall()
columns = {index: metadata.name for index, metadata in enumerate(snowflake_cursor.description)}
stage_files = [{columns[index]: row[index] for index in columns.keys()} for row in rows]
assert len(stage_files) == 1
assert stage_files[0]["name"] == f"{data_interval_end_str}/{os.path.basename(garbage_jsonl_file)}.gz"

await activity_environment.run(insert_into_snowflake_activity, insert_inputs)

await assert_clickhouse_records_in_snowflake(
snowflake_cursor=snowflake_cursor,
clickhouse_client=clickhouse_client,
table_name=table_name,
team_id=ateam.pk,
data_interval_start=data_interval_start,
data_interval_end=data_interval_end,
batch_export_model=model,
sort_key="event",
)

snowflake_cursor.execute(list_query)
rows = snowflake_cursor.fetchall()
assert len(rows) == 0


@SKIP_IF_MISSING_REQUIRED_ENV_VARS
@pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
@pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
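For reference, the LIST-parsing pattern the new test uses to verify the stage contents — mapping `cursor.description` column metadata onto each returned row so staged files can be checked by name — shown in isolation. This is a sketch assuming a live `snowflake.connector` connection with placeholder credentials and a hypothetical table name:

```python
import snowflake.connector

# Placeholder credentials; the test gets a ready-made cursor from the snowflake_cursor fixture.
conn = snowflake.connector.connect(user="user", password="password", account="account")

table_name = "test_insert_activity_table_remove_1"  # hypothetical table

with conn.cursor() as cursor:
    # LIST returns one row per file in the table's internal stage.
    cursor.execute(f"""LIST '@%"{table_name}"'""")
    rows = cursor.fetchall()

    # cursor.description holds one metadata entry per result column; map positions to
    # column names so each staged file becomes a plain dict keyed by column name.
    columns = {index: metadata.name for index, metadata in enumerate(cursor.description)}
    stage_files = [{columns[index]: row[index] for index in columns} for row in rows]

    for stage_file in stage_files:
        print(stage_file["name"])
```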
