fix: Snowflake's execute_async doesn't support PUT

tomasfarias committed Nov 6, 2023
1 parent 48a3046 commit 6b047de
Showing 1 changed file with 82 additions and 57 deletions.
139 changes: 82 additions & 57 deletions posthog/temporal/workflows/snowflake_batch_export.py
@@ -1,14 +1,15 @@
import asyncio
import contextlib
import datetime as dt
+import functools
+import io
import json
import typing
from dataclasses import dataclass

import snowflake.connector
from django.conf import settings
from snowflake.connector.connection import SnowflakeConnection
-from snowflake.connector.cursor import SnowflakeCursor
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

@@ -69,6 +70,12 @@ class SnowflakeInsertInputs:
    include_events: list[str] | None = None


+def use_namespace(connection: SnowflakeConnection, database: str, schema: str) -> None:
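+    """Set the session's database and schema so subsequent statements can use unqualified table names."""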
+    cursor = connection.cursor()
+    cursor.execute(f'USE DATABASE "{database}"')
+    cursor.execute(f'USE SCHEMA "{schema}"')


@contextlib.contextmanager
def snowflake_connection(inputs) -> typing.Generator[SnowflakeConnection, None, None]:
    with snowflake.connector.connect(
@@ -80,12 +87,18 @@ def snowflake_connection(inputs) -> typing.Generator[SnowflakeConnection, None,
        schema=inputs.schema,
        role=inputs.role,
    ) as connection:
+        connection.cursor().execute("SET ABORT_DETACHED_QUERY = FALSE")
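+        # ABORT_DETACHED_QUERY = FALSE asks Snowflake to keep queries running even if
+        # the submitting session disconnects; the async polling below relies on that.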

        yield connection


async def execute_async_query(
-    connection: SnowflakeConnection, query: str, parameters: dict | None = None, poll_interval: float = 1.0
-) -> SnowflakeCursor:
+    connection: SnowflakeConnection,
+    query: str,
+    parameters: dict | None = None,
+    file_stream=None,
+    poll_interval: float = 1.0,
+) -> str:
"""Wrap Snowflake connector's polling API in a coroutine.
This enables asynchronous execution of queries to release the event loop to execute other tasks
Expand All @@ -99,35 +112,33 @@ async def execute_async_query(
"""
cursor = connection.cursor()

cursor.execute_async(query, parameters or {})
# Snowflake docs incorrectly state that the 'params' argument is named 'parameters'.
result = cursor.execute_async(query, params=parameters, file_stream=file_stream)
query_id = result["queryId"]
query_status = None

try:
query_id = cursor.sfqid
query_status = connection.get_query_status_throw_if_error(query_id)

while connection.is_still_running(connection.get_query_status_throw_if_error(query_id)):
while connection.is_still_running(query_status):
query_status = connection.get_query_status_throw_if_error(query_id)
await asyncio.sleep(poll_interval)

except snowflake.connector.ProgrammingError:
# TODO: logging
# TODO: logging? Other handling?
raise

else:
cursor.get_results_from_sfqid(query_id)

finally:
pass
return query_id

return cursor
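
A quick usage sketch for this helper, assuming a live SnowflakeConnection and an illustrative "events" table: await it for the query id, then attach the results to a fresh cursor, as copy_loaded_files_to_snowflake_table does below.

    query_id = await execute_async_query(connection, 'SELECT COUNT(*) FROM "events"')
    cursor = connection.cursor()
    cursor.get_results_from_sfqid(query_id)
    rows = cursor.fetchall()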


-async def create_table_in_snowflake(conn: SnowflakeConnection, database: str, schema: str, table_name: str) -> None:
+async def create_table_in_snowflake(connection: SnowflakeConnection, table_name: str) -> None:
    """Asynchronously create the table if it doesn't exist.

    Note that we use the same schema as the snowflake-plugin for backwards compatibility."""
    await execute_async_query(
-        conn,
+        connection,
        f"""
-        CREATE TABLE IF NOT EXISTS "{database}"."{schema}"."{table_name}" (
+        CREATE TABLE IF NOT EXISTS "{table_name}" (
            "uuid" STRING,
            "event" STRING,
            "properties" VARIANT,
@@ -149,33 +160,40 @@ async def put_file_to_snowflake_table(
    connection: SnowflakeConnection,
    file: BatchExportTemporaryFile,
    table_name: str,
-    database: str | None = None,
-    schema: str | None = None,
+    file_no: int,
):
    """Execute a PUT query over the provided connection to stage the given file for table_name.

+    Sadly, Snowflake's execute_async does not work with PUT statements. So, we pass the execute
+    call to run_in_executor: since execute boils down to blocking IO (an HTTP request), the
+    event loop should not be locked up.
+
+    We add a file_no to the file_name when executing PUT, as Snowflake will reject any files with
+    the same name. Since batch exports re-use the same local file, its name does not change, but we
+    don't want Snowflake to reject or overwrite our new data.

    Args:
        connection: A SnowflakeConnection object as produced by snowflake.connector.connect.
        file: The local file to PUT.
        table_name: The Snowflake table to PUT the file into.
+        file_no: An int to identify which file number this is.

    Raises:
        TypeError: If we don't get a tuple back from Snowflake (should never happen).
        SnowflakeFileNotUploadedError: If the upload status is not 'UPLOADED'.
    """
    file.flush()
    file.rewind()

-    if database is not None and schema is not None:
-        namespace = f"{database}.{schema}."
-    else:
-        namespace = ""
+    # We comply with the file-like interface of io.IOBase.
+    # So we ask mypy to be nice with us.
+    reader = io.BufferedReader(file)  # type: ignore
+    query = f'PUT file://{file.name}_{file_no}.jsonl @%"{table_name}"'
+    cursor = connection.cursor()

-    cursor = await execute_async_query(
-        connection,
-        f"""
-        PUT file://{file.name} @%{namespace}"{table_name}"
-        """,
-    )
+    execute_put = functools.partial(cursor.execute, query, file_stream=reader)
+
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, func=execute_put)
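+    # Passing None as the executor runs execute_put on asyncio's default
+    # ThreadPoolExecutor, keeping the blocking upload off the event loop.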

    result = cursor.fetchone()
    if not isinstance(result, tuple):
@@ -190,23 +208,17 @@ async def put_file_to_snowflake_table(
async def copy_loaded_files_to_snowflake_table(
    connection: SnowflakeConnection,
    table_name: str,
-    database: str | None = None,
-    schema: str | None = None,
):
-    if database is not None and schema is not None:
-        namespace = f"{database}.{schema}."
-    else:
-        namespace = ""
+    query = f"""
+    COPY INTO "{table_name}"
+    FILE_FORMAT = (TYPE = 'JSON')
+    MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
+    PURGE = TRUE
+    """
+    query_id = await execute_async_query(connection, query)

-    cursor = await execute_async_query(
-        connection,
-        f"""
-        COPY INTO {namespace}"{table_name}"
-        FILE_FORMAT = (TYPE = 'JSON')
-        MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
-        PURGE = TRUE
-        """,
-    )
+    cursor = connection.cursor()
+    cursor.get_results_from_sfqid(query_id)
    results = cursor.fetchall()

    for query_result in results:
@@ -219,7 +231,7 @@ async def copy_loaded_files_to_snowflake_table(
                table_name,
                "NO STATUS",
                0,
-                query_result[1] if len(query_result) == 1 else "NO ERROR MESSAGE",
+                query_result[0] if len(query_result) == 1 else "NO ERROR MESSAGE",
            )

        _, status = query_result[0:2]
@@ -271,7 +283,9 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):
    logger.info("BatchExporting %s rows to Snowflake", count)

    with snowflake_connection(inputs) as connection:
-        await create_table_in_snowflake(connection, inputs.database, inputs.schema, inputs.table_name)
+        use_namespace(connection, inputs.database, inputs.schema)
+
+        await create_table_in_snowflake(connection, inputs.table_name)

        results_iterator = get_results_iterator(
            client=client,
@@ -284,6 +298,7 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):

        result = None
        last_uploaded_file_timestamp = None
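+        # Snowflake rejects a PUT that reuses a previously staged file name, so
+        # each upload below gets its own file_no suffix.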
+        file_no = 0

        async def worker_shutdown_handler():
            """Handle the Worker shutting down by heart-beating our latest status."""
@@ -297,7 +312,20 @@ async def worker_shutdown_handler():

        with BatchExportTemporaryFile() as local_results_file:
            for result in results_iterator:
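+                # Select fields explicitly so the JSONL keys line up with the columns
+                # created above (e.g. "set" is exported as "people_set").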
-                local_results_file.write_records_to_jsonl([result])
+                record = {
+                    "uuid": result["uuid"],
+                    "event": result["event"],
+                    "properties": result["properties"],
+                    "elements": result["elements"],
+                    "people_set": result["set"],
+                    "people_set_once": result["set_once"],
+                    "distinct_id": result["distinct_id"],
+                    "team_id": result["team_id"],
+                    "ip": result["ip"],
+                    "site_url": result["site_url"],
+                    "timestamp": result["timestamp"],
+                }
+                local_results_file.write_records_to_jsonl([record])

                if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES:
                    logger.info(
@@ -307,12 +335,13 @@ async def worker_shutdown_handler():
                    )

                    await put_file_to_snowflake_table(
-                        connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
+                        connection, local_results_file, inputs.table_name, file_no=file_no
                    )

                    last_uploaded_file_timestamp = result["inserted_at"]
-                    activity.heartbeat(last_uploaded_file_timestamp)
+                    activity.heartbeat(last_uploaded_file_timestamp, file_no)

+                    file_no += 1
                    local_results_file.reset()

            if local_results_file.tell() > 0 and result is not None:
@@ -321,16 +350,14 @@ async def worker_shutdown_handler():
                    local_results_file.records_since_last_reset,
                    local_results_file.bytes_since_last_reset,
                )
+                last_uploaded_file_timestamp = result["inserted_at"]

                await put_file_to_snowflake_table(
-                    connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
+                    connection, local_results_file, inputs.table_name, file_no=file_no
                )

-                last_uploaded_file_timestamp = result["inserted_at"]
-                activity.heartbeat(last_uploaded_file_timestamp)
+                activity.heartbeat(last_uploaded_file_timestamp, file_no)

-        await copy_loaded_files_to_snowflake_table(connection, inputs.table_name, inputs.database, inputs.schema)
+        await copy_loaded_files_to_snowflake_table(connection, inputs.table_name)
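
Heartbeats now carry both the last uploaded timestamp and the file number. A retried activity could pick these back up along these lines (a hedged sketch, not part of this commit; it assumes only temporalio's activity.info() API and a hypothetical helper name):

    from temporalio import activity

    def previous_attempt_progress() -> tuple[str | None, int]:
        """Recover (last_uploaded_file_timestamp, next file_no) from heartbeat details, if any."""
        details = activity.info().heartbeat_details
        if not details:
            return None, 0
        # Heartbeats above are sent as activity.heartbeat(last_uploaded_file_timestamp, file_no),
        # so the details arrive in that order.
        last_uploaded_file_timestamp, file_no = details[0], details[1]
        return last_uploaded_file_timestamp, file_no + 1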


@workflow.defn(name="snowflake-export")
@@ -411,6 +438,4 @@ async def run(self, inputs: SnowflakeBatchExportInputs):
                "ForbiddenError",
            ],
            update_inputs=update_inputs,
-            # Disable heartbeat timeout until we add heartbeat support.
-            heartbeat_timeout_seconds=None,
        )
