From 6b047de7c9a8ad4c043a842d7444a0d487271fc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Mon, 6 Nov 2023 18:33:05 +0100 Subject: [PATCH] fix: Snowflake's execute_async doesn't support PUT --- .../workflows/snowflake_batch_export.py | 139 +++++++++++------- 1 file changed, 82 insertions(+), 57 deletions(-) diff --git a/posthog/temporal/workflows/snowflake_batch_export.py b/posthog/temporal/workflows/snowflake_batch_export.py index 010dd9c66c5e1..4428704706bc9 100644 --- a/posthog/temporal/workflows/snowflake_batch_export.py +++ b/posthog/temporal/workflows/snowflake_batch_export.py @@ -1,6 +1,8 @@ import asyncio import contextlib import datetime as dt +import functools +import io import json import typing from dataclasses import dataclass @@ -8,7 +10,6 @@ import snowflake.connector from django.conf import settings from snowflake.connector.connection import SnowflakeConnection -from snowflake.connector.cursor import SnowflakeCursor from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -69,6 +70,12 @@ class SnowflakeInsertInputs: include_events: list[str] | None = None +def use_namespace(connection: SnowflakeConnection, database: str, schema: str) -> None: + cursor = connection.cursor() + cursor.execute(f'USE DATABASE "{database}"') + cursor.execute(f'USE SCHEMA "{schema}"') + + @contextlib.contextmanager def snowflake_connection(inputs) -> typing.Generator[SnowflakeConnection, None, None]: with snowflake.connector.connect( @@ -80,12 +87,18 @@ def snowflake_connection(inputs) -> typing.Generator[SnowflakeConnection, None, schema=inputs.schema, role=inputs.role, ) as connection: + connection.cursor().execute("SET ABORT_DETACHED_QUERY = FALSE") + yield connection async def execute_async_query( - connection: SnowflakeConnection, query: str, parameters: dict | None = None, poll_interval: float = 1.0 -) -> SnowflakeCursor: + connection: SnowflakeConnection, + query: str, + parameters: dict | None = None, + file_stream=None, + poll_interval: float = 1.0, +) -> str: """Wrap Snowflake connector's polling API in a coroutine. This enables asynchronous execution of queries to release the event loop to execute other tasks @@ -99,35 +112,33 @@ async def execute_async_query( """ cursor = connection.cursor() - cursor.execute_async(query, parameters or {}) + # Snowflake docs incorrectly state that the 'params' argument is named 'parameters'. + result = cursor.execute_async(query, params=parameters, file_stream=file_stream) + query_id = result["queryId"] + query_status = None try: - query_id = cursor.sfqid + query_status = connection.get_query_status_throw_if_error(query_id) - while connection.is_still_running(connection.get_query_status_throw_if_error(query_id)): + while connection.is_still_running(query_status): + query_status = connection.get_query_status_throw_if_error(query_id) await asyncio.sleep(poll_interval) except snowflake.connector.ProgrammingError: - # TODO: logging + # TODO: logging? Other handling? raise - else: - cursor.get_results_from_sfqid(query_id) - - finally: - pass + return query_id - return cursor - -async def create_table_in_snowflake(conn: SnowflakeConnection, database: str, schema: str, table_name: str) -> None: +async def create_table_in_snowflake(connection: SnowflakeConnection, table_name: str) -> None: """Asynchronously create the table if it doesn't exist. 
Note that we use the same schema as the snowflake-plugin for backwards compatibility."""
     await execute_async_query(
-        conn,
+        connection,
         f"""
-        CREATE TABLE IF NOT EXISTS "{database}"."{schema}"."{table_name}" (
+        CREATE TABLE IF NOT EXISTS "{table_name}" (
             "uuid" STRING,
             "event" STRING,
             "properties" VARIANT,
@@ -149,33 +160,40 @@ async def put_file_to_snowflake_table(
     connection: SnowflakeConnection,
     file: BatchExportTemporaryFile,
     table_name: str,
-    database: str | None = None,
-    schema: str | None = None,
+    file_no: int,
 ):
-    """Executes a PUT query using the provided cursor to the provided table_name.
+    """Execute a PUT query to upload a local file to the provided table_name's table stage.
 
+    Sadly, Snowflake's execute_async does not work with PUT statements, so we pass the execute
+    call to run_in_executor: since execute ultimately boils down to blocking IO (an HTTP request),
+    the event loop will not be locked up.
+
+    We add a file_no to the file name when executing PUT, as Snowflake will reject any files with
+    the same name. Since batch exports reuse the same local file, its name does not change, but we
+    don't want Snowflake to reject or overwrite our new data.
+
     Args:
         connection: A SnowflakeConnection object as produced by snowflake.connector.connect.
-        file: The name of the local file to PUT.
-        table_name: The name of the Snowflake table where to PUT the file.
+        file: The local file to PUT.
+        table_name: The name of the Snowflake table where the file will be PUT.
+        file_no: An int used to number the file so that its name in the stage stays unique.
 
     Raises:
         TypeError: If we don't get a tuple back from Snowflake (should never happen).
         SnowflakeFileNotUploadedError: If the upload status is not 'UPLOADED'.
     """
-    file.flush()
+    file.rewind()
 
-    if database is not None and schema is not None:
-        namespace = f"{database}.{schema}."
-    else:
-        namespace = ""
+    # BatchExportTemporaryFile complies with the file-like interface of io.IOBase,
+    # so we ask mypy to be lenient with us.
+    reader = io.BufferedReader(file)  # type: ignore
+    query = f'PUT file://{file.name}_{file_no}.jsonl @%"{table_name}"'
+    cursor = connection.cursor()
 
-    cursor = await execute_async_query(
-        connection,
-        f"""
-        PUT file://{file.name} @%{namespace}"{table_name}"
-        """,
-    )
+    execute_put = functools.partial(cursor.execute, query, file_stream=reader)
+
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, func=execute_put)
 
     result = cursor.fetchone()
     if not isinstance(result, tuple):
@@ -190,23 +208,17 @@ async def copy_loaded_files_to_snowflake_table(
     connection: SnowflakeConnection,
     table_name: str,
-    database: str | None = None,
-    schema: str | None = None,
 ):
-    if database is not None and schema is not None:
-        namespace = f"{database}.{schema}."
- else: - namespace = "" + query = f""" + COPY INTO "{table_name}" + FILE_FORMAT = (TYPE = 'JSON') + MATCH_BY_COLUMN_NAME = CASE_SENSITIVE + PURGE = TRUE + """ + query_id = await execute_async_query(connection, query) - cursor = await execute_async_query( - connection, - f""" - COPY INTO {namespace}"{table_name}" - FILE_FORMAT = (TYPE = 'JSON') - MATCH_BY_COLUMN_NAME = CASE_SENSITIVE - PURGE = TRUE - """, - ) + cursor = connection.cursor() + cursor.get_results_from_sfqid(query_id) results = cursor.fetchall() for query_result in results: @@ -219,7 +231,7 @@ async def copy_loaded_files_to_snowflake_table( table_name, "NO STATUS", 0, - query_result[1] if len(query_result) == 1 else "NO ERROR MESSAGE", + query_result[0] if len(query_result) == 1 else "NO ERROR MESSAGE", ) _, status = query_result[0:2] @@ -271,7 +283,9 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs): logger.info("BatchExporting %s rows to Snowflake", count) with snowflake_connection(inputs) as connection: - await create_table_in_snowflake(connection, inputs.database, inputs.schema, inputs.table_name) + use_namespace(connection, inputs.database, inputs.schema) + + await create_table_in_snowflake(connection, inputs.table_name) results_iterator = get_results_iterator( client=client, @@ -284,6 +298,7 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs): result = None last_uploaded_file_timestamp = None + file_no = 0 async def worker_shutdown_handler(): """Handle the Worker shutting down by heart-beating our latest status.""" @@ -297,7 +312,20 @@ async def worker_shutdown_handler(): with BatchExportTemporaryFile() as local_results_file: for result in results_iterator: - local_results_file.write_records_to_jsonl([result]) + record = { + "uuid": result["uuid"], + "event": result["event"], + "properties": result["properties"], + "elements": result["elements"], + "people_set": result["set"], + "people_set_once": result["set_once"], + "distinct_id": result["distinct_id"], + "team_id": result["team_id"], + "ip": result["ip"], + "site_url": result["site_url"], + "timestamp": result["timestamp"], + } + local_results_file.write_records_to_jsonl([record]) if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: logger.info( @@ -307,12 +335,13 @@ async def worker_shutdown_handler(): ) await put_file_to_snowflake_table( - connection, local_results_file, inputs.table_name, inputs.database, inputs.schema + connection, local_results_file, inputs.table_name, file_no=file_no ) last_uploaded_file_timestamp = result["inserted_at"] - activity.heartbeat(last_uploaded_file_timestamp) + activity.heartbeat(last_uploaded_file_timestamp, file_no) + file_no += 1 local_results_file.reset() if local_results_file.tell() > 0 and result is not None: @@ -321,16 +350,14 @@ async def worker_shutdown_handler(): local_results_file.records_since_last_reset, local_results_file.bytes_since_last_reset, ) - last_uploaded_file_timestamp = result["inserted_at"] - await put_file_to_snowflake_table( - connection, local_results_file, inputs.table_name, inputs.database, inputs.schema + connection, local_results_file, inputs.table_name, file_no=file_no ) last_uploaded_file_timestamp = result["inserted_at"] - activity.heartbeat(last_uploaded_file_timestamp) + activity.heartbeat(last_uploaded_file_timestamp, file_no) - await copy_loaded_files_to_snowflake_table(connection, inputs.table_name, inputs.database, inputs.schema) + await copy_loaded_files_to_snowflake_table(connection, inputs.table_name) 
@workflow.defn(name="snowflake-export") @@ -411,6 +438,4 @@ async def run(self, inputs: SnowflakeBatchExportInputs): "ForbiddenError", ], update_inputs=update_inputs, - # Disable heartbeat timeout until we add heartbeat support. - heartbeat_timeout_seconds=None, )
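
Notes (not part of the patch):

For reviewers, a minimal sketch of the polling pattern that execute_async_query wraps, using only documented snowflake-connector-python calls (execute_async, get_query_status_throw_if_error, is_still_running, get_results_from_sfqid); the connection parameters and the wait_for_query name are placeholders for illustration:

import asyncio

import snowflake.connector


async def wait_for_query(connection, query: str, poll_interval: float = 1.0) -> str:
    """Submit a query with execute_async and poll until Snowflake reports it finished."""
    cursor = connection.cursor()
    # execute_async returns immediately with metadata that includes the query id.
    result = cursor.execute_async(query)
    query_id = result["queryId"]

    # Poll the query's status, yielding to the event loop between checks.
    status = connection.get_query_status_throw_if_error(query_id)
    while connection.is_still_running(status):
        await asyncio.sleep(poll_interval)
        status = connection.get_query_status_throw_if_error(query_id)

    return query_id


async def main() -> None:
    # Placeholder credentials; the real values come from SnowflakeInsertInputs.
    with snowflake.connector.connect(user="user", password="password", account="account") as connection:
        query_id = await wait_for_query(connection, "SELECT 1")
        cursor = connection.cursor()
        # Attach the cursor to the finished query to fetch its results.
        cursor.get_results_from_sfqid(query_id)
        print(cursor.fetchall())


if __name__ == "__main__":
    asyncio.run(main())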
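Similarly, a sketch of the run_in_executor workaround used above for PUT, since execute_async does not accept PUT statements; put_file and its parameters are hypothetical names, and any open binary file object works as file_stream:

import asyncio
import functools


async def put_file(connection, path: str, table_name: str) -> None:
    """PUT a local file into a table's stage without blocking the event loop."""
    with open(path, "rb") as file_stream:
        query = f'PUT file://{path} @%"{table_name}"'
        cursor = connection.cursor()
        # cursor.execute is blocking IO (an HTTP request under the hood), so we run
        # it in the default thread-pool executor instead of calling it directly.
        execute_put = functools.partial(cursor.execute, query, file_stream=file_stream)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, execute_put)

        # PUT returns one row describing the uploaded file.
        print(cursor.fetchone())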