Commit
test: Add proper Snowflake tests
tomasfarias committed Nov 6, 2023
1 parent 4592f3f commit 84d50ad
Showing 1 changed file with 295 additions and 11 deletions.
@@ -2,13 +2,16 @@
import datetime as dt
import gzip
import json
import os
import random
import re
from collections import deque
from uuid import uuid4

import pytest
import pytest_asyncio
import responses
import snowflake.connector
from django.conf import settings
from django.test import override_settings
from requests.models import PreparedRequest
@@ -204,21 +207,52 @@ def query_request_handler(request: PreparedRequest):
return queries, staged_files


@pytest.fixture
def database():
"""Generate a unique database name for tests."""
return f"test_batch_exports_{uuid4()}"


@pytest.fixture
def schema():
"""Generate a unique schema name for tests."""
return f"test_batch_exports_{uuid4()}"


@pytest.fixture
def table_name(ateam, interval):
return f"test_workflow_table_{ateam.pk}_{interval}"


@pytest.fixture
def snowflake_config(database, schema) -> dict[str, str]:
"""Return a Snowflake configuration dictionary to use in tests.
We set default configuration values to support tests against the Snowflake API
and tests that mock it.
"""
password = os.getenv("SNOWFLAKE_PASSWORD", "password")
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH")
account = os.getenv("SNOWFLAKE_ACCOUNT", "account")
username = os.getenv("SNOWFLAKE_USERNAME", "hazzadous")

return {
"password": password,
"user": username,
"warehouse": warehouse,
"account": account,
"database": database,
"schema": schema,
}


@pytest_asyncio.fixture
-async def snowflake_batch_export(ateam, interval, temporal_client):
+async def snowflake_batch_export(ateam, table_name, snowflake_config, interval, exclude_events, temporal_client):
"""Manage BatchExport model (and associated Temporal Schedule) for tests"""
destination_data = {
"type": "Snowflake",
-"config": {
-"user": "hazzadous",
-"password": "password",
-"account": "account",
-"database": "PostHog",
-"schema": "test",
-"warehouse": "COMPUTE_WH",
-"table_name": "events",
-},
+"config": {**snowflake_config, "table_name": table_name, "exclude_events": exclude_events},
}

batch_export_data = {
"name": "my-production-snowflake-export",
"destination": destination_data,
@@ -624,3 +658,253 @@ async def never_finish_activity(_: SnowflakeInsertInputs) -> str:
run = runs[0]
assert run.status == "Cancelled"
assert run.latest_error == "Cancelled"


def assert_events_in_snowflake(
cursor: snowflake.connector.cursor.SnowflakeCursor, table_name: str, events: list, exclude_events: list[str] | None
):
"""Assert provided events are present in Snowflake table."""
cursor.execute(f'SELECT * FROM "{table_name}"')

rows = cursor.fetchall()

columns = {index: metadata.name for index, metadata in enumerate(cursor.description)}
json_columns = ("properties", "elements", "people_set", "people_set_once")
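# These columns come back from Snowflake as JSON strings, so parse them before comparing against the expected events.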

# Rows are tuples, so we construct a dictionary using the metadata from cursor.description.
# We rely on the order of the columns in each row matching the order set in cursor.description.
# This seems to be the case, at least for now.
inserted_events = [
{
columns[index]: json.loads(row[index])
if columns[index] in json_columns and row[index] is not None
else row[index]
for index in columns.keys()
}
for row in rows
]
inserted_events.sort(key=lambda x: (x["event"], x["timestamp"]))

expected_events = []
for event in events:
event_name = event.get("event")

if exclude_events is not None and event_name in exclude_events:
continue

properties = event.get("properties", None)
elements_chain = event.get("elements_chain", None)
expected_event = {
"distinct_id": event.get("distinct_id"),
"elements": json.dumps(elements_chain),
"event": event_name,
"ip": properties.get("$ip", None) if properties else None,
"properties": event.get("properties"),
"people_set": properties.get("$set", None) if properties else None,
"people_set_once": properties.get("$set_once", None) if properties else None,
"site_url": "",
"timestamp": dt.datetime.fromisoformat(event.get("timestamp")),
"team_id": event.get("team_id"),
"uuid": event.get("uuid"),
}
expected_events.append(expected_event)

expected_events.sort(key=lambda x: (x["event"], x["timestamp"]))

assert inserted_events[0] == expected_events[0]
assert inserted_events == expected_events


REQUIRED_ENV_VARS = (
"SNOWFLAKE_WAREHOUSE",
"SNOWFLAKE_PASSWORD",
"SNOWFLAKE_ACCOUNT",
"SNOWFLAKE_USERNAME",
)

SKIP_IF_MISSING_REQUIRED_ENV_VARS = pytest.mark.skipif(
any(env_var not in os.environ for env_var in REQUIRED_ENV_VARS),
reason="Snowflake required env vars are not set",
)
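# To exercise the real-Snowflake tests locally, export the four variables above with valid credentials; otherwise they are skipped.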


@pytest.fixture
def snowflake_cursor(snowflake_config):
"""Manage a snowflake cursor that cleans up after we are done."""
with snowflake.connector.connect(
user=snowflake_config["user"],
password=snowflake_config["password"],
account=snowflake_config["account"],
warehouse=snowflake_config["warehouse"],
) as connection:
cursor = connection.cursor()
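# Create a throw-away database and schema for this test run; they are dropped again once the test finishes.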
cursor.execute(f"CREATE DATABASE \"{snowflake_config['database']}\"")
cursor.execute(f"CREATE SCHEMA \"{snowflake_config['database']}\".\"{snowflake_config['schema']}\"")
cursor.execute(f"USE SCHEMA \"{snowflake_config['database']}\".\"{snowflake_config['schema']}\"")

yield cursor

cursor.execute(f"DROP DATABASE IF EXISTS \"{snowflake_config['database']}\" CASCADE")


@SKIP_IF_MISSING_REQUIRED_ENV_VARS
@pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
async def test_insert_into_snowflake_activity_inserts_data_into_snowflake_table(
clickhouse_client, activity_environment, snowflake_cursor, snowflake_config, exclude_events
):
"""Test that the insert_into_snowflake_activity function inserts data into a PostgreSQL table.
We use the generate_test_events_in_clickhouse function to generate several sets
of events. Some of these sets are expected to be exported, and others not. Expected
events are those that:
* Are created for the team_id of the batch export.
* Are created in the date range of the batch export.
* Are not duplicates of other events that are in the same batch.
* Do not have an event name contained in the batch export's exclude_events.
Once we have these events, we pass them to the assert_events_in_snowflake function to check
that they appear in the expected Snowflake table. This function runs against a real Snowflake
instance, so the environment should be populated with the necessary credentials.
"""
data_interval_start = dt.datetime(2023, 4, 20, 14, 0, 0, tzinfo=dt.timezone.utc)
data_interval_end = dt.datetime(2023, 4, 25, 15, 0, 0, tzinfo=dt.timezone.utc)

team_id = random.randint(1, 1000000)
(events, _, _) = await generate_test_events_in_clickhouse(
client=clickhouse_client,
team_id=team_id,
start_time=data_interval_start,
end_time=data_interval_end,
count=1000,
count_outside_range=10,
count_other_team=10,
duplicate=True,
properties={"$browser": "Chrome", "$os": "Mac OS X"},
person_properties={"utm_medium": "referral", "$initial_os": "Linux"},
)

events_to_exclude = []
if exclude_events:
for event_name in exclude_events:
(events_to_exclude_for_event_name, _, _) = await generate_test_events_in_clickhouse(
client=clickhouse_client,
team_id=team_id,
start_time=data_interval_start,
end_time=data_interval_end,
count=5,
count_outside_range=0,
count_other_team=0,
event_name=event_name,
)
events_to_exclude += events_to_exclude_for_event_name

table_name = f"test_insert_activity_table_{team_id}"
insert_inputs = SnowflakeInsertInputs(
team_id=team_id,
table_name=table_name,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
**snowflake_config,
)

await activity_environment.run(insert_into_snowflake_activity, insert_inputs)

assert_events_in_snowflake(
cursor=snowflake_cursor,
table_name=table_name,
events=events + events_to_exclude,
exclude_events=exclude_events,
)


@SKIP_IF_MISSING_REQUIRED_ENV_VARS
@pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
@pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
async def test_snowflake_export_workflow(
clickhouse_client,
snowflake_cursor,
interval,
snowflake_batch_export,
ateam,
exclude_events,
):
"""Test Redshift Export Workflow end-to-end.
The workflow should update the batch export run status to completed and produce the expected
records to the provided Redshift instance.
"""
data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta
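# The parametrized interval (hour or day) determines how much data before data_interval_end is exported.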

(events, _, _) = await generate_test_events_in_clickhouse(
client=clickhouse_client,
team_id=ateam.pk,
start_time=data_interval_start,
end_time=data_interval_end,
count=100,
count_outside_range=10,
count_other_team=10,
duplicate=True,
properties={"$browser": "Chrome", "$os": "Mac OS X"},
person_properties={"utm_medium": "referral", "$initial_os": "Linux"},
)

events_to_exclude = []
if exclude_events:
for event_name in exclude_events:
(events_to_exclude_for_event_name, _, _) = await generate_test_events_in_clickhouse(
client=clickhouse_client,
team_id=ateam.pk,
start_time=data_interval_start,
end_time=data_interval_end,
count=5,
count_outside_range=0,
count_other_team=0,
event_name=event_name,
)
events_to_exclude += events_to_exclude_for_event_name

workflow_id = str(uuid4())
inputs = SnowflakeBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(snowflake_batch_export.id),
data_interval_end="2023-04-25 14:30:00.000000",
interval=interval,
**snowflake_batch_export.destination.config,
)

async with await WorkflowEnvironment.start_time_skipping() as activity_environment:
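# The time-skipping test environment auto-advances Temporal timers, so the workflow runs without real waiting.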
async with Worker(
activity_environment.client,
task_queue=settings.TEMPORAL_TASK_QUEUE,
workflows=[SnowflakeBatchExportWorkflow],
activities=[
create_export_run,
insert_into_snowflake_activity,
update_export_run_status,
],
workflow_runner=UnsandboxedWorkflowRunner(),
):
with override_settings(BATCH_EXPORT_REDSHIFT_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
await activity_environment.client.execute_workflow(
SnowflakeBatchExportWorkflow.run,
inputs,
id=workflow_id,
task_queue=settings.TEMPORAL_TASK_QUEUE,
retry_policy=RetryPolicy(maximum_attempts=1),
execution_timeout=dt.timedelta(seconds=10),
)

runs = await afetch_batch_export_runs(batch_export_id=snowflake_batch_export.id)
assert len(runs) == 1

run = runs[0]
assert run.status == "Completed"

assert_events_in_snowflake(
cursor=snowflake_cursor,
table_name=snowflake_batch_export.destination.config["table_name"],
events=events + events_to_exclude,
exclude_events=exclude_events,
)
