Commit 48a3046

refactor: Snowflake batch export is now async
tomasfarias committed Nov 6, 2023
1 parent 9a826ae commit 48a3046
Showing 1 changed file with 193 additions and 135 deletions.
328 changes: 193 additions & 135 deletions posthog/temporal/workflows/snowflake_batch_export.py
@@ -1,17 +1,21 @@
import asyncio
import contextlib
import datetime as dt
import json
import tempfile
import typing
from dataclasses import dataclass

import snowflake.connector
from django.conf import settings
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

from posthog.batch_exports.service import SnowflakeBatchExportInputs
from posthog.temporal.workflows.base import PostHogWorkflow
from posthog.temporal.workflows.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
@@ -65,23 +69,114 @@ class SnowflakeInsertInputs:
include_events: list[str] | None = None


def put_file_to_snowflake_table(cursor: SnowflakeCursor, file_name: str, table_name: str):
@contextlib.contextmanager
def snowflake_connection(inputs) -> typing.Generator[SnowflakeConnection, None, None]:
    with snowflake.connector.connect(
        user=inputs.user,
        password=inputs.password,
        account=inputs.account,
        warehouse=inputs.warehouse,
        database=inputs.database,
        schema=inputs.schema,
        role=inputs.role,
    ) as connection:
        yield connection


async def execute_async_query(
    connection: SnowflakeConnection, query: str, parameters: dict | None = None, poll_interval: float = 1.0
) -> SnowflakeCursor:
    """Wrap Snowflake connector's polling API in a coroutine.

    This enables asynchronous execution of queries to release the event loop to execute other tasks
    while we poll for a query to be done. For example, the event loop may use this time for heartbeating.

    Args:
        connection: A SnowflakeConnection object as produced by snowflake.connector.connect.
        query: A query string to run asynchronously.
        parameters: An optional dictionary of parameters to bind to the query.
        poll_interval: Specify how long to wait in between polls.
    """
    cursor = connection.cursor()

    cursor.execute_async(query, parameters or {})

    try:
        query_id = cursor.sfqid

        while connection.is_still_running(connection.get_query_status_throw_if_error(query_id)):
            await asyncio.sleep(poll_interval)

    except snowflake.connector.ProgrammingError:
        # TODO: logging
        raise

    else:
        cursor.get_results_from_sfqid(query_id)

    finally:
        pass

    return cursor
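
The polling wrapper above is the core of the refactor: because execute_async_query sleeps between polls instead of blocking, the event loop can run other coroutines (such as Temporal heartbeats) while Snowflake executes the query. A minimal sketch of that idea follows; the helper names are illustrative and not part of this file:

async def run_query_with_heartbeats(connection: SnowflakeConnection, query: str) -> SnowflakeCursor:
    """Illustrative sketch: keep heartbeating a Temporal activity while a Snowflake query runs."""

    async def heartbeat_periodically(interval: float = 5.0) -> None:
        while True:
            activity.heartbeat()
            await asyncio.sleep(interval)

    heartbeater = asyncio.create_task(heartbeat_periodically())
    try:
        # execute_async_query yields to the event loop on every poll, so the
        # heartbeat task above keeps running while the query is still executing.
        return await execute_async_query(connection, query)
    finally:
        heartbeater.cancel()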


async def create_table_in_snowflake(conn: SnowflakeConnection, database: str, schema: str, table_name: str) -> None:
    """Asynchronously create the table if it doesn't exist.

    Note that we use the same schema as the snowflake-plugin for backwards compatibility.
    """
    await execute_async_query(
        conn,
        f"""
        CREATE TABLE IF NOT EXISTS "{database}"."{schema}"."{table_name}" (
            "uuid" STRING,
            "event" STRING,
            "properties" VARIANT,
            "elements" VARIANT,
            "people_set" VARIANT,
            "people_set_once" VARIANT,
            "distinct_id" STRING,
            "team_id" INTEGER,
            "ip" STRING,
            "site_url" STRING,
            "timestamp" TIMESTAMP
        )
        COMMENT = 'PostHog generated events table'
        """,
    )


async def put_file_to_snowflake_table(
connection: SnowflakeConnection,
file: BatchExportTemporaryFile,
table_name: str,
database: str | None = None,
schema: str | None = None,
):
"""Executes a PUT query using the provided cursor to the provided table_name.
Args:
cursor: A Snowflake cursor to execute the PUT query.
file_name: The name of the file to PUT.
table_name: The name of the table where to PUT the file.
connection: A SnowflakeConnection object as produced by snowflake.connector.connect.
file: The name of the local file to PUT.
table_name: The name of the Snowflake table where to PUT the file.
Raises:
TypeError: If we don't get a tuple back from Snowflake (should never happen).
SnowflakeFileNotUploadedError: If the upload status is not 'UPLOADED'.
"""
cursor.execute(
file.flush()

if database is not None and schema is not None:
namespace = f"{database}.{schema}."
else:
namespace = ""

cursor = await execute_async_query(
connection,
f"""
PUT file://{file_name} @%"{table_name}"
"""
PUT file://{file.name} @%{namespace}"{table_name}"
""",
)

result = cursor.fetchone()
if not isinstance(result, tuple):
# Mostly to appease mypy, as this query should always return a tuple.
@@ -92,6 +187,53 @@ def put_file_to_snowflake_table(cursor: SnowflakeCursor, file_name: str, table_n
raise SnowflakeFileNotUploadedError(table_name, status, message)
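
For reference, this is roughly how the PUT helper is driven from the activity below: records are buffered in a BatchExportTemporaryFile and the file is staged once it grows large enough. A condensed sketch, with made-up table, database, and schema names:

async def stage_one_chunk(connection: SnowflakeConnection) -> None:
    """Sketch: write records to a temporary JSONL file and PUT it to the table's stage."""
    with BatchExportTemporaryFile() as jsonl_file:
        # The record shape below is illustrative; the activity streams real events from ClickHouse.
        jsonl_file.write_records_to_jsonl([{"event": "pageview", "distinct_id": "some-user"}])
        await put_file_to_snowflake_table(connection, jsonl_file, "events", "MY_DATABASE", "PUBLIC")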


async def copy_loaded_files_to_snowflake_table(
    connection: SnowflakeConnection,
    table_name: str,
    database: str | None = None,
    schema: str | None = None,
):
    """Execute a COPY INTO query to load any staged files into the Snowflake table."""
    if database is not None and schema is not None:
        namespace = f"{database}.{schema}."
    else:
        namespace = ""

    cursor = await execute_async_query(
        connection,
        f"""
        COPY INTO {namespace}"{table_name}"
        FILE_FORMAT = (TYPE = 'JSON')
        MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
        PURGE = TRUE
        """,
    )
    results = cursor.fetchall()

    for query_result in results:
        if not isinstance(query_result, tuple):
            # Mostly to appease mypy, as this query should always return a tuple.
            raise TypeError(f"Expected tuple from Snowflake COPY INTO query but got: '{type(query_result)}'")

        if len(query_result) < 2:
            raise SnowflakeFileNotLoadedError(
                table_name,
                "NO STATUS",
                0,
                query_result[1] if len(query_result) == 1 else "NO ERROR MESSAGE",
            )

        _, status = query_result[0:2]

        if status != "LOADED":
            errors_seen, first_error = query_result[5:7]
            raise SnowflakeFileNotLoadedError(
                table_name,
                status or "NO STATUS",
                errors_seen or 0,
                first_error or "NO ERROR MESSAGE",
            )


@activity.defn
async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):
"""Activity streams data from ClickHouse to Snowflake.
@@ -128,41 +270,8 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):

logger.info("BatchExporting %s rows to Snowflake", count)

conn = snowflake.connector.connect(
user=inputs.user,
password=inputs.password,
account=inputs.account,
warehouse=inputs.warehouse,
database=inputs.database,
schema=inputs.schema,
role=inputs.role,
)

try:
cursor = conn.cursor()
cursor.execute(f'USE DATABASE "{inputs.database}"')
cursor.execute(f'USE SCHEMA "{inputs.schema}"')

# Create the table if it doesn't exist. Note that we use the same schema
# as the snowflake-plugin for backwards compatibility.
cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS "{inputs.database}"."{inputs.schema}"."{inputs.table_name}" (
"uuid" STRING,
"event" STRING,
"properties" VARIANT,
"elements" VARIANT,
"people_set" VARIANT,
"people_set_once" VARIANT,
"distinct_id" STRING,
"team_id" INTEGER,
"ip" STRING,
"site_url" STRING,
"timestamp" TIMESTAMP
)
COMMENT = 'PostHog generated events table'
"""
)
with snowflake_connection(inputs) as connection:
await create_table_in_snowflake(connection, inputs.database, inputs.schema, inputs.table_name)

results_iterator = get_results_iterator(
client=client,
@@ -172,107 +281,56 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
)

result = None
local_results_file = tempfile.NamedTemporaryFile(suffix=".jsonl")
try:
while True:
try:
result = results_iterator.__next__()
last_uploaded_file_timestamp = None

except StopIteration:
break
async def worker_shutdown_handler():
"""Handle the Worker shutting down by heart-beating our latest status."""
await activity.wait_for_worker_shutdown()
logger.warn(
f"Worker shutting down! Reporting back latest exported part {last_uploaded_file_timestamp}",
)
activity.heartbeat(last_uploaded_file_timestamp)

except json.JSONDecodeError:
asyncio.create_task(worker_shutdown_handler())

with BatchExportTemporaryFile() as local_results_file:
for result in results_iterator:
local_results_file.write_records_to_jsonl([result])

if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES:
logger.info(
"Failed to decode a JSON value while iterating, potentially due to a ClickHouse error"
)
# This is raised by aiochclient as we try to decode an error message from ClickHouse.
# So far, this error message only indicated that we were too slow consuming rows.
# So, we can resume from the last result.
if result is None:
# We failed right at the beginning
new_interval_start = None
else:
new_interval_start = result.get("inserted_at", None)

if not isinstance(new_interval_start, str):
new_interval_start = inputs.data_interval_start

results_iterator = get_results_iterator(
client=client,
team_id=inputs.team_id,
interval_start=new_interval_start, # This means we'll generate at least one duplicate.
interval_end=inputs.data_interval_end,
"Putting file containing %s records with size %s bytes",
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)
continue

if not result:
break

# Write the results to a local file
local_results_file.write(json.dumps(result).encode("utf-8"))
local_results_file.write("\n".encode("utf-8"))

# Write results to Snowflake when the file reaches 50MB and
# reset the file, or if there is nothing else to write.
if (
local_results_file.tell()
and local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES
):
logger.info("Uploading to Snowflake")

# Flush the file to make sure everything is written
local_results_file.flush()
put_file_to_snowflake_table(cursor, local_results_file.name, inputs.table_name)

# Delete the temporary file and create a new one
local_results_file.close()
local_results_file = tempfile.NamedTemporaryFile(suffix=".jsonl")

# Flush the file to make sure everything is written
local_results_file.flush()
put_file_to_snowflake_table(cursor, local_results_file.name, inputs.table_name)

# We don't need the file anymore, close (and delete) it.
local_results_file.close()
cursor.execute(
f"""
COPY INTO "{inputs.table_name}"
FILE_FORMAT = (TYPE = 'JSON')
MATCH_BY_COLUMN_NAME = CASE_SENSITIVE
PURGE = TRUE
"""
)
results = cursor.fetchall()

for query_result in results:
if not isinstance(query_result, tuple):
# Mostly to appease mypy, as this query should always return a tuple.
raise TypeError(f"Expected tuple from Snowflake COPY INTO query but got: '{type(result)}'")

if len(query_result) < 2:
raise SnowflakeFileNotLoadedError(
inputs.table_name,
"NO STATUS",
0,
query_result[1] if len(query_result) == 1 else "NO ERROR MESSAGE",

await put_file_to_snowflake_table(
connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
)

_, status = query_result[0:2]
last_uploaded_file_timestamp = result["inserted_at"]
activity.heartbeat(last_uploaded_file_timestamp)

if status != "LOADED":
errors_seen, first_error = query_result[5:7]
raise SnowflakeFileNotLoadedError(
inputs.table_name,
status or "NO STATUS",
errors_seen or 0,
first_error or "NO ERROR MESSAGE",
)
local_results_file.reset()

if local_results_file.tell() > 0 and result is not None:
logger.info(
"Putting last file containing %s records with size %s bytes",
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)
last_uploaded_file_timestamp = result["inserted_at"]

await put_file_to_snowflake_table(
connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
)

last_uploaded_file_timestamp = result["inserted_at"]
activity.heartbeat(last_uploaded_file_timestamp)

finally:
local_results_file.close()
finally:
conn.close()
await copy_loaded_files_to_snowflake_table(connection, inputs.table_name, inputs.database, inputs.schema)
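
Taken together, the new activity body reduces to the shape sketched below: open a connection, ensure the table exists, stream rows into a temporary JSONL file, PUT each chunk and heartbeat the last exported timestamp, then COPY the staged files into the table. This is a condensed illustration of the flow in this diff, not the verbatim implementation; logging, the worker-shutdown handler, and error handling are omitted:

async def export_batch_sketch(inputs: SnowflakeInsertInputs, results_iterator) -> None:
    """Condensed sketch of the async export flow: PUT chunks, heartbeat progress, COPY INTO."""
    with snowflake_connection(inputs) as connection:
        await create_table_in_snowflake(connection, inputs.database, inputs.schema, inputs.table_name)

        result = None
        with BatchExportTemporaryFile() as local_results_file:
            for result in results_iterator:
                local_results_file.write_records_to_jsonl([result])

                if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES:
                    await put_file_to_snowflake_table(
                        connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
                    )
                    activity.heartbeat(result["inserted_at"])
                    local_results_file.reset()

            if local_results_file.tell() > 0 and result is not None:
                await put_file_to_snowflake_table(
                    connection, local_results_file, inputs.table_name, inputs.database, inputs.schema
                )
                activity.heartbeat(result["inserted_at"])

        await copy_loaded_files_to_snowflake_table(connection, inputs.table_name, inputs.database, inputs.schema)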


@workflow.defn(name="snowflake-export")
