diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 7e597590ff987..efd10468fc77b 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -261,6 +261,15 @@ async def execute_async_query( results = cursor.fetchall() return results + async def aremove_internal_stage_files(self, table_name: str, table_stage_prefix: str) -> None: + """Asynchronously remove files from internal table stage. + + Arguments: + table_name: The name of the table whose internal stage to clear. + table_stage_prefix: Prefix to path of internal stage files. + """ + await self.execute_async_query(f"""REMOVE '@%"{table_name}"/{table_stage_prefix}'""") + async def acreate_table(self, table_name: str, fields: list[SnowflakeField]) -> None: """Asynchronously create the table if it doesn't exist. @@ -297,24 +306,30 @@ async def adelete_table( async def managed_table( self, table_name: str, + table_stage_prefix: str, fields: list[SnowflakeField], not_found_ok: bool = True, delete: bool = True, create: bool = True, ) -> collections.abc.AsyncGenerator[str, None]: """Manage a table in Snowflake by ensure it exists while in context.""" - if create: + if create is True: await self.acreate_table(table_name, fields) + else: + await self.aremove_internal_stage_files(table_name, table_stage_prefix) try: yield table_name finally: if delete is True: await self.adelete_table(table_name, not_found_ok) + else: + await self.aremove_internal_stage_files(table_name, table_stage_prefix) async def put_file_to_snowflake_table( self, file: BatchExportTemporaryFile, + table_stage_prefix: str, table_name: str, file_no: int, ): @@ -342,7 +357,9 @@ async def put_file_to_snowflake_table( # We comply with the file-like interface of io.IOBase. # So we ask mypy to be nice with us. reader = io.BufferedReader(file) # type: ignore - query = f'PUT file://{file.name}_{file_no}.jsonl @%"{table_name}"' + query = f""" + PUT file://{file.name}_{file_no}.jsonl '@%"{table_name}"/{table_stage_prefix}' + """ with self.connection.cursor() as cursor: cursor = self.connection.cursor() @@ -366,6 +383,7 @@ async def put_file_to_snowflake_table( async def copy_loaded_files_to_snowflake_table( self, table_name: str, + table_stage_prefix: str, ) -> None: """Execute a COPY query in Snowflake to load any files PUT into the table. @@ -377,6 +395,7 @@ async def copy_loaded_files_to_snowflake_table( """ query = f""" COPY INTO "{table_name}" + FROM '@%"{table_name}"/{table_stage_prefix}' FILE_FORMAT = (TYPE = 'JSON') MATCH_BY_COLUMN_NAME = CASE_SENSITIVE PURGE = TRUE @@ -631,9 +650,11 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor async with SnowflakeClient.from_inputs(inputs).connect() as snow_client: async with ( - snow_client.managed_table(inputs.table_name, table_fields, delete=False) as snow_table, snow_client.managed_table( - stagle_table_name, table_fields, create=requires_merge, delete=requires_merge + inputs.table_name, data_interval_end_str, table_fields, delete=False + ) as snow_table, + snow_client.managed_table( + stagle_table_name, data_interval_end_str, table_fields, create=requires_merge, delete=requires_merge ) as snow_stage_table, ): record_columns = [field[0] for field in table_fields] @@ -660,7 +681,9 @@ async def flush_to_snowflake( table = snow_stage_table if requires_merge else snow_table - await snow_client.put_file_to_snowflake_table(local_results_file, table, flush_counter) + await snow_client.put_file_to_snowflake_table( + local_results_file, data_interval_end_str, table, flush_counter + ) rows_exported.add(records_since_last_flush) bytes_exported.add(bytes_since_last_flush) @@ -678,7 +701,7 @@ async def flush_to_snowflake( await writer.write_record_batch(record_batch) await snow_client.copy_loaded_files_to_snowflake_table( - snow_stage_table if requires_merge else snow_table + snow_stage_table if requires_merge else snow_table, data_interval_end_str ) if requires_merge: merge_key = ( diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index aa7626db3e4ea..8c6d944fd3394 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -5,6 +5,7 @@ import operator import os import re +import tempfile import unittest.mock import uuid from collections import deque @@ -394,6 +395,7 @@ async def test_snowflake_export_workflow_exports_events( count. """ data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + data_interval_end_str = data_interval_end.strftime("%Y-%m-%d_%H-%M-%S") data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta await generate_test_events_in_clickhouse( @@ -451,12 +453,12 @@ async def test_snowflake_export_workflow_exports_events( execute_calls = [] for cursor in fake_conn._cursors: for call in cursor._execute_calls: - execute_calls.append(call["query"]) + execute_calls.append(call["query"].strip()) execute_async_calls = [] for cursor in fake_conn._cursors: for call in cursor._execute_async_calls: - execute_async_calls.append(call["query"]) + execute_async_calls.append(call["query"].strip()) assert execute_async_calls[0:3] == [ f'USE DATABASE "{database}"', @@ -467,8 +469,9 @@ async def test_snowflake_export_workflow_exports_events( assert all(query.startswith("PUT") for query in execute_calls[0:9]) assert all(f"_{n}.jsonl" in query for n, query in enumerate(execute_calls[0:9])) - assert execute_async_calls[3].strip().startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"') - assert execute_async_calls[4].strip().startswith(f'COPY INTO "{table_name}"') + assert execute_async_calls[3].startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"') + assert execute_async_calls[4].startswith(f"""REMOVE '@%"{table_name}"/{data_interval_end_str}'""") + assert execute_async_calls[5].startswith(f'COPY INTO "{table_name}"') runs = await afetch_batch_export_runs(batch_export_id=snowflake_batch_export.id) assert len(runs) == 1 @@ -1141,6 +1144,104 @@ async def test_insert_into_snowflake_activity_merges_data_in_follow_up_runs( ) +@pytest.fixture +def garbage_jsonl_file(): + """Manage a JSON file with garbage data.""" + with tempfile.NamedTemporaryFile("w+b", suffix=".jsonl", prefix="garbage_") as garbage_jsonl_file: + garbage_jsonl_file.write(b'{"team_id": totally not an integer}\n') + garbage_jsonl_file.seek(0) + + yield garbage_jsonl_file.name + + +@SKIP_IF_MISSING_REQUIRED_ENV_VARS +async def test_insert_into_snowflake_activity_removes_internal_stage_files( + clickhouse_client, + activity_environment, + snowflake_cursor, + snowflake_config, + generate_test_data, + data_interval_start, + data_interval_end, + ateam, + garbage_jsonl_file, +): + """Test that the `insert_into_snowflake_activity` removes internal stage files. + + This test requires some setup steps: + 1. We do a first run of the activity to create the export table. Since we + haven't added any garbage, this should work normally. + 2. Truncate the table to avoid duplicate data once we re-run the activity. + 3. PUT a file with garbage data into the table internal stage. + + Once we run the activity a second time, it should first clear up the garbage + file and not fail the COPY. After this second execution is done, and besides + checking this second run worked and exported data, we also check that no files + are present in the table's internal stage. + """ + model = BatchExportModel(name="events", schema=None) + + table_name = f"test_insert_activity_table_remove_{ateam.pk}" + + insert_inputs = SnowflakeInsertInputs( + team_id=ateam.pk, + table_name=table_name, + data_interval_start=data_interval_start.isoformat(), + data_interval_end=data_interval_end.isoformat(), + batch_export_model=model, + **snowflake_config, + ) + + await activity_environment.run(insert_into_snowflake_activity, insert_inputs) + + await assert_clickhouse_records_in_snowflake( + snowflake_cursor=snowflake_cursor, + clickhouse_client=clickhouse_client, + table_name=table_name, + team_id=ateam.pk, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + batch_export_model=model, + sort_key="event", + ) + + snowflake_cursor.execute(f'TRUNCATE TABLE "{table_name}"') + + data_interval_end_str = data_interval_end.strftime("%Y-%m-%d_%H-%M-%S") + + put_query = f""" + PUT file://{garbage_jsonl_file} '@%"{table_name}"/{data_interval_end_str}' + """ + snowflake_cursor.execute(put_query) + + list_query = f""" + LIST '@%"{table_name}"' + """ + snowflake_cursor.execute(list_query) + rows = snowflake_cursor.fetchall() + columns = {index: metadata.name for index, metadata in enumerate(snowflake_cursor.description)} + stage_files = [{columns[index]: row[index] for index in columns.keys()} for row in rows] + assert len(stage_files) == 1 + assert stage_files[0]["name"] == f"{data_interval_end_str}/{os.path.basename(garbage_jsonl_file)}.gz" + + await activity_environment.run(insert_into_snowflake_activity, insert_inputs) + + await assert_clickhouse_records_in_snowflake( + snowflake_cursor=snowflake_cursor, + clickhouse_client=clickhouse_client, + table_name=table_name, + team_id=ateam.pk, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + batch_export_model=model, + sort_key="event", + ) + + snowflake_cursor.execute(list_query) + rows = snowflake_cursor.fetchall() + assert len(rows) == 0 + + @SKIP_IF_MISSING_REQUIRED_ENV_VARS @pytest.mark.parametrize("interval", ["hour", "day"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)