Commit

test: Update mock tests

tomasfarias committed Nov 7, 2023
1 parent 84d50ad commit 1e80500

Showing 2 changed files with 130 additions and 54 deletions.
@@ -5,6 +5,7 @@
import os
import random
import re
+import unittest.mock
from collections import deque
from uuid import uuid4

@@ -22,7 +23,6 @@
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

-from posthog.temporal.tests.utils.datetimes import to_isoformat
from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
from posthog.temporal.tests.utils.models import acreate_batch_export, adelete_batch_export, afetch_batch_export_runs
from posthog.temporal.workflows.batch_exports import (
@@ -39,6 +39,92 @@
pytestmark = [pytest.mark.asyncio, pytest.mark.django_db]


+class FakeSnowflakeCursor:
+    """A fake Snowflake cursor that can fail on PUT and COPY queries."""
+
+    def __init__(self, *args, failure_mode: str | None = None, **kwargs):
+        self._execute_calls = []
+        self._execute_async_calls = []
+        self._sfqid = 1
+        self._fail = failure_mode
+
+    @property
+    def sfqid(self):
+        current = self._sfqid
+        self._sfqid += 1
+        return current
+
+    def execute(self, query, params=None, file_stream=None):
+        self._execute_calls.append({"query": query, "params": params, "file_stream": file_stream})
+
+    def execute_async(self, query, params=None, file_stream=None):
+        self._execute_async_calls.append({"query": query, "params": params, "file_stream": file_stream})
+
+    def get_results_from_sfqid(self, query_id):
+        pass
+
+    def fetchone(self):
+        if self._fail == "put":
+            return (
+                "test",
+                "test.gz",
+                456,
+                0,
+                "NONE",
+                "GZIP",
+                "FAILED",
+                "Some error on put",
+            )
+        else:
+            return (
+                "test",
+                "test.gz",
+                456,
+                0,
+                "NONE",
+                "GZIP",
+                "UPLOADED",
+                None,
+            )
+
+    def fetchall(self):
+        if self._fail == "copy":
+            return [("test", "LOAD FAILED", 100, 99, 1, 1, "Some error on copy", 3)]
+        else:
+            return [("test", "LOADED", 100, 99, 1, 1, "Some error on copy", 3)]
+
+
+class FakeSnowflakeConnection:
+    def __init__(
+        self,
+        *args,
+        failure_mode: str | None = None,
+        **kwargs,
+    ):
+        self._cursors = []
+        self._is_running = True
+        self.failure_mode = failure_mode
+
+    def cursor(self) -> FakeSnowflakeCursor:
+        cursor = FakeSnowflakeCursor(failure_mode=self.failure_mode)
+        self._cursors.append(cursor)
+        return cursor
+
+    def get_query_status_throw_if_error(self, query_id):
+        return snowflake.connector.constants.QueryStatus.SUCCESS
+
+    def is_still_running(self, status):
+        current_status = self._is_running
+        self._is_running = not current_status
+        return current_status
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
def contains_queries_in_order(queries: list[str], *queries_to_find: str):
    """Check if a list of queries contains a list of queries in order."""
    # We use a deque to pop the queries we find off the list of queries to
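The body of contains_queries_in_order is collapsed in this diff. A sketch consistent with its signature and the deque comment above, treating each query to find as a regular expression, would be (an illustration, not the committed implementation):

    import re
    from collections import deque


    def contains_queries_in_order(queries: list[str], *queries_to_find: str) -> bool:
        """Check if a list of queries contains a list of queries in order."""
        # Pop each pattern off the deque once a query matches it; if the deque
        # empties, every pattern was found, in order.
        remaining = deque(queries_to_find)
        for query in queries:
            if remaining and re.search(remaining[0], query):
                remaining.popleft()
        return not remaining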
@@ -272,7 +358,9 @@ async def snowflake_batch_export(ateam, table_name, snowflake_config, interval,


@pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
-async def test_snowflake_export_workflow_exports_events(ateam, clickhouse_client, snowflake_batch_export, interval):
+async def test_snowflake_export_workflow_exports_events(
+    ateam, clickhouse_client, database, schema, snowflake_batch_export, interval, table_name
+):
    """Test that the whole workflow, not just the activity, works.

    It should update the batch export run status to completed, as well as update the record
@@ -281,7 +369,7 @@ async def test_snowflake_export_workflow_exports_events(ateam, clickhouse_client
    data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
    data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta

-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
        client=clickhouse_client,
        team_id=ateam.pk,
        start_time=data_interval_start,
@@ -315,10 +403,12 @@ async def test_snowflake_export_workflow_exports_events(ateam, clickhouse_client
            ],
            workflow_runner=UnsandboxedWorkflowRunner(),
        ):
-            with responses.RequestsMock(
-                target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-            ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-                queries, staged_files = add_mock_snowflake_api(rsps)
+            with unittest.mock.patch(
+                "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+            ) as mock, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1):
+                fake_conn = FakeSnowflakeConnection()
+                mock.return_value = fake_conn

                await activity_environment.client.execute_workflow(
                    SnowflakeBatchExportWorkflow.run,
                    inputs,
@@ -328,49 +418,27 @@ async def test_snowflake_export_workflow_exports_events(ateam, clickhouse_client
                    retry_policy=RetryPolicy(maximum_attempts=1),
                )

-                assert contains_queries_in_order(
-                    queries,
-                    'USE DATABASE "PostHog"',
-                    'USE SCHEMA "test"',
-                    'CREATE TABLE IF NOT EXISTS "PostHog"."test"."events"',
-                    # NOTE: we check that we at least have two PUT queries to
-                    # ensure we hit the multi file upload code path
-                    'PUT file://.* @%"events"',
-                    'PUT file://.* @%"events"',
-                    'COPY INTO "events"',
-                )
+                execute_calls = []
+                for cursor in fake_conn._cursors:
+                    for call in cursor._execute_calls:
+                        execute_calls.append(call["query"])

-                staged_data = "\n".join(staged_files)
+                execute_async_calls = []
+                for cursor in fake_conn._cursors:
+                    for call in cursor._execute_async_calls:
+                        execute_async_calls.append(call["query"])

-                # Check that the data is correct.
-                json_data = [json.loads(line) for line in staged_data.split("\n") if line]
-                # Pull out the fields we inserted only
-                json_data = [
-                    {
-                        "uuid": event["uuid"],
-                        "event": event["event"],
-                        "timestamp": event["timestamp"],
-                        "properties": event["properties"],
-                        "person_id": event["person_id"],
-                    }
-                    for event in json_data
+                assert execute_calls[0:3] == [
+                    "SET ABORT_DETACHED_QUERY = FALSE",
+                    f'USE DATABASE "{database}"',
+                    f'USE SCHEMA "{schema}"',
                ]
-                json_data.sort(key=lambda x: x["timestamp"])

-                # Drop _timestamp and team_id from events
-                expected_events = []
-                for event in events:
-                    expected_event = {
-                        key: value
-                        for key, value in event.items()
-                        if key in ("uuid", "event", "timestamp", "properties", "person_id")
-                    }
-                    expected_event["timestamp"] = to_isoformat(event["timestamp"])
-                    expected_events.append(expected_event)
-                expected_events.sort(key=lambda x: x["timestamp"])
+                assert all(query.startswith("PUT") for query in execute_calls[3:12])
+                assert all(f"_{n}.jsonl" in query for n, query in enumerate(execute_calls[3:12]))

-                assert json_data[0] == expected_events[0]
-                assert json_data == expected_events
+                assert execute_async_calls[0].strip().startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"')
+                assert execute_async_calls[1].strip().startswith(f'COPY INTO "{table_name}"')

    runs = await afetch_batch_export_runs(batch_export_id=snowflake_batch_export.id)
    assert len(runs) == 1
@@ -485,11 +553,15 @@ async def test_snowflake_export_workflow_raises_error_on_put_fail(
            ],
            workflow_runner=UnsandboxedWorkflowRunner(),
        ):
-            with responses.RequestsMock(
-                target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-            ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-                add_mock_snowflake_api(rsps, fail="put")

+            class FakeSnowflakeConnectionFailOnPut(FakeSnowflakeConnection):
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, failure_mode="put", **kwargs)

+            with unittest.mock.patch(
+                "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+                side_effect=FakeSnowflakeConnectionFailOnPut,
+            ):
                with pytest.raises(WorkflowFailureError) as exc_info:
                    await activity_environment.client.execute_workflow(
                        SnowflakeBatchExportWorkflow.run,
@@ -547,11 +619,15 @@ async def test_snowflake_export_workflow_raises_error_on_copy_fail(
            ],
            workflow_runner=UnsandboxedWorkflowRunner(),
        ):
-            with responses.RequestsMock(
-                target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-            ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-                add_mock_snowflake_api(rsps, fail="copy")

+            class FakeSnowflakeConnectionFailOnCopy(FakeSnowflakeConnection):
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, failure_mode="copy", **kwargs)

+            with unittest.mock.patch(
+                "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+                side_effect=FakeSnowflakeConnectionFailOnCopy,
+            ):
                with pytest.raises(WorkflowFailureError) as exc_info:
                    await activity_environment.client.execute_workflow(
                        SnowflakeBatchExportWorkflow.run,
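Two patching styles appear in these tests: the happy-path test sets mock.return_value to one pre-built FakeSnowflakeConnection it can inspect afterwards, while the failure tests pass a subclass as side_effect so every connect() call builds a fresh fake with the failure mode baked in. A minimal, self-contained sketch of the difference (the connect and FakeConnection names here are illustrative, not part of the commit):

    import unittest.mock


    def connect(*args, **kwargs):
        raise RuntimeError("would hit a real database")


    class FakeConnection:
        def __init__(self, *args, **kwargs):
            self.args = args


    def code_under_test():
        # Stands in for the workflow code, which calls connect() internally.
        return connect(user="test")


    # Style 1: return_value -- every call yields the same fake, which the test
    # keeps a reference to and inspects after the workflow finishes.
    fake = FakeConnection()
    with unittest.mock.patch("__main__.connect", return_value=fake):
        assert code_under_test() is fake

    # Style 2: side_effect with a class -- each call constructs a fresh instance,
    # so the failure mode lives in the subclass rather than in test wiring.
    with unittest.mock.patch("__main__.connect", side_effect=FakeConnection):
        assert isinstance(code_under_test(), FakeConnection)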
2 changes: 1 addition & 1 deletion posthog/temporal/workflows/snowflake_batch_export.py
@@ -114,7 +114,7 @@ async def execute_async_query(

    # Snowflake docs incorrectly state that the 'params' argument is named 'parameters'.
    result = cursor.execute_async(query, params=parameters, file_stream=file_stream)
-    query_id = result["queryId"]
+    query_id = cursor.sfqid or result["queryId"]
    query_status = None

    try:
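The one-line change above makes execute_async_query read the query id from cursor.sfqid, falling back to the dict returned by execute_async; that is also why FakeSnowflakeCursor implements an incrementing sfqid property. As an approximation of the polling flow these tests exercise, inferred from the methods the fakes implement rather than copied from the real module:

    import asyncio


    async def execute_async_query(connection, cursor, query):
        # Submit without blocking; the cursor records the submitted query's id.
        cursor.execute_async(query)
        query_id = cursor.sfqid

        # Poll Snowflake until the query stops running, raising on failure.
        status = connection.get_query_status_throw_if_error(query_id)
        while connection.is_still_running(status):
            await asyncio.sleep(1)
            status = connection.get_query_status_throw_if_error(query_id)

        # Attach the finished query's result set to the cursor and read it.
        cursor.get_results_from_sfqid(query_id)
        return cursor.fetchall()

FakeSnowflakeConnection.is_still_running flips from True to False after one call, so a loop shaped like this runs exactly one polling iteration per query under test.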
