diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
index ddc9c044554c6f..5fb5fd817cf134 100644
--- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
@@ -5,6 +5,7 @@
 import os
 import random
 import re
+import unittest.mock
 from collections import deque
 from uuid import uuid4
 
@@ -22,7 +23,6 @@
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
 
-from posthog.temporal.tests.utils.datetimes import to_isoformat
 from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 from posthog.temporal.tests.utils.models import acreate_batch_export, adelete_batch_export, afetch_batch_export_runs
 from posthog.temporal.workflows.batch_exports import (
@@ -39,6 +39,92 @@
 pytestmark = [pytest.mark.asyncio, pytest.mark.django_db]
 
 
+class FakeSnowflakeCursor:
+    """A fake Snowflake cursor that can fail on PUT and COPY queries."""
+
+    def __init__(self, *args, failure_mode: str | None = None, **kwargs):
+        self._execute_calls = []
+        self._execute_async_calls = []
+        self._sfqid = 1
+        self._fail = failure_mode
+
+    @property
+    def sfqid(self):
+        current = self._sfqid
+        self._sfqid += 1
+        return current
+
+    def execute(self, query, params=None, file_stream=None):
+        self._execute_calls.append({"query": query, "params": params, "file_stream": file_stream})
+
+    def execute_async(self, query, params=None, file_stream=None):
+        self._execute_async_calls.append({"query": query, "params": params, "file_stream": file_stream})
+
+    def get_results_from_sfqid(self, query_id):
+        pass
+
+    def fetchone(self):
+        if self._fail == "put":
+            return (
+                "test",
+                "test.gz",
+                456,
+                0,
+                "NONE",
+                "GZIP",
+                "FAILED",
+                "Some error on put",
+            )
+        else:
+            return (
+                "test",
+                "test.gz",
+                456,
+                0,
+                "NONE",
+                "GZIP",
+                "UPLOADED",
+                None,
+            )
+
+    def fetchall(self):
+        if self._fail == "copy":
+            return [("test", "LOAD FAILED", 100, 99, 1, 1, "Some error on copy", 3)]
+        else:
+            return [("test", "LOADED", 100, 99, 1, 1, "Some error on copy", 3)]
+
+
+class FakeSnowflakeConnection:
+    def __init__(
+        self,
+        *args,
+        failure_mode: str | None = None,
+        **kwargs,
+    ):
+        self._cursors = []
+        self._is_running = True
+        self.failure_mode = failure_mode
+
+    def cursor(self) -> FakeSnowflakeCursor:
+        cursor = FakeSnowflakeCursor(failure_mode=self.failure_mode)
+        self._cursors.append(cursor)
+        return cursor
+
+    def get_query_status_throw_if_error(self, query_id):
+        return snowflake.connector.constants.QueryStatus.SUCCESS
+
+    def is_still_running(self, status):
+        current_status = self._is_running
+        self._is_running = not current_status
+        return current_status
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
 def contains_queries_in_order(queries: list[str], *queries_to_find: str):
     """Check if a list of queries contains a list of queries in order."""
     # We use a deque to pop the queries we find off the list of queries to
@@ -272,7 +358,9 @@ async def snowflake_batch_export(ateam, table_name, snowflake_config, interval,
 
 
 @pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
-async def test_snowflake_export_workflow_exports_events(ateam, clickhouse_client, snowflake_batch_export, interval):
+async def test_snowflake_export_workflow_exports_events(
+    ateam, clickhouse_client, database, schema, snowflake_batch_export, interval, table_name
+):
     """Test that the whole workflow not just the activity works.
 
     It should update the batch export run status to completed, as well as updating the record
@@ -281,7 +369,7 @@
     data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
     data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=ateam.pk,
         start_time=data_interval_start,
@@ -315,10 +403,12 @@
         ],
         workflow_runner=UnsandboxedWorkflowRunner(),
     ):
-        with responses.RequestsMock(
-            target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-        ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-            queries, staged_files = add_mock_snowflake_api(rsps)
+        with unittest.mock.patch(
+            "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+        ) as mock, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1):
+            fake_conn = FakeSnowflakeConnection()
+            mock.return_value = fake_conn
+
             await activity_environment.client.execute_workflow(
                 SnowflakeBatchExportWorkflow.run,
                 inputs,
@@ -328,49 +418,27 @@
                 retry_policy=RetryPolicy(maximum_attempts=1),
             )
 
-            assert contains_queries_in_order(
-                queries,
-                'USE DATABASE "PostHog"',
-                'USE SCHEMA "test"',
-                'CREATE TABLE IF NOT EXISTS "PostHog"."test"."events"',
-                # NOTE: we check that we at least have two PUT queries to
-                # ensure we hit the multi file upload code path
-                'PUT file://.* @%"events"',
-                'PUT file://.* @%"events"',
-                'COPY INTO "events"',
-            )
+            execute_calls = []
+            for cursor in fake_conn._cursors:
+                for call in cursor._execute_calls:
+                    execute_calls.append(call["query"])
 
-            staged_data = "\n".join(staged_files)
+            execute_async_calls = []
+            for cursor in fake_conn._cursors:
+                for call in cursor._execute_async_calls:
+                    execute_async_calls.append(call["query"])
 
-            # Check that the data is correct.
-            json_data = [json.loads(line) for line in staged_data.split("\n") if line]
-            # Pull out the fields we inserted only
-            json_data = [
-                {
-                    "uuid": event["uuid"],
-                    "event": event["event"],
-                    "timestamp": event["timestamp"],
-                    "properties": event["properties"],
-                    "person_id": event["person_id"],
-                }
-                for event in json_data
+            assert execute_calls[0:3] == [
+                "SET ABORT_DETACHED_QUERY = FALSE",
+                f'USE DATABASE "{database}"',
+                f'USE SCHEMA "{schema}"',
             ]
-            json_data.sort(key=lambda x: x["timestamp"])
 
-            # Drop _timestamp and team_id from events
-            expected_events = []
-            for event in events:
-                expected_event = {
-                    key: value
-                    for key, value in event.items()
-                    if key in ("uuid", "event", "timestamp", "properties", "person_id")
-                }
-                expected_event["timestamp"] = to_isoformat(event["timestamp"])
-                expected_events.append(expected_event)
-            expected_events.sort(key=lambda x: x["timestamp"])
+            assert all(query.startswith("PUT") for query in execute_calls[3:12])
+            assert all(f"_{n}.jsonl" in query for n, query in enumerate(execute_calls[3:12]))
 
-            assert json_data[0] == expected_events[0]
-            assert json_data == expected_events
+            assert execute_async_calls[0].strip().startswith(f'CREATE TABLE IF NOT EXISTS "{table_name}"')
+            assert execute_async_calls[1].strip().startswith(f'COPY INTO "{table_name}"')
 
     runs = await afetch_batch_export_runs(batch_export_id=snowflake_batch_export.id)
     assert len(runs) == 1
@@ -485,11 +553,15 @@ async def test_snowflake_export_workflow_raises_error_on_put_fail(
         ],
         workflow_runner=UnsandboxedWorkflowRunner(),
     ):
-        with responses.RequestsMock(
-            target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-        ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-            add_mock_snowflake_api(rsps, fail="put")
+        class FakeSnowflakeConnectionFailOnPut(FakeSnowflakeConnection):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, failure_mode="put", **kwargs)
+
+        with unittest.mock.patch(
+            "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+            side_effect=FakeSnowflakeConnectionFailOnPut,
+        ):
 
             with pytest.raises(WorkflowFailureError) as exc_info:
                 await activity_environment.client.execute_workflow(
                     SnowflakeBatchExportWorkflow.run,
@@ -547,11 +619,15 @@
         ],
         workflow_runner=UnsandboxedWorkflowRunner(),
    ):
-        with responses.RequestsMock(
-            target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send"
-        ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2):
-            add_mock_snowflake_api(rsps, fail="copy")
+        class FakeSnowflakeConnectionFailOnCopy(FakeSnowflakeConnection):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, failure_mode="copy", **kwargs)
+
+        with unittest.mock.patch(
+            "posthog.temporal.workflows.snowflake_batch_export.snowflake.connector.connect",
+            side_effect=FakeSnowflakeConnectionFailOnCopy,
+        ):
 
             with pytest.raises(WorkflowFailureError) as exc_info:
                 await activity_environment.client.execute_workflow(
                     SnowflakeBatchExportWorkflow.run,
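As a side note, the nested loops in the first test that collect the recorded queries can be flattened into comprehensions; a minimal sketch, assuming a FakeSnowflakeConnection instance named fake_conn as in the test body:

    # Equivalent one-liners for gathering the queries recorded by the fake cursors.
    execute_calls = [call["query"] for cursor in fake_conn._cursors for call in cursor._execute_calls]
    execute_async_calls = [call["query"] for cursor in fake_conn._cursors for call in cursor._execute_async_calls]
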
diff --git a/posthog/temporal/workflows/snowflake_batch_export.py b/posthog/temporal/workflows/snowflake_batch_export.py
index 4428704706bc99..c45887bde2948f 100644
--- a/posthog/temporal/workflows/snowflake_batch_export.py
+++ b/posthog/temporal/workflows/snowflake_batch_export.py
@@ -114,7 +114,7 @@ async def execute_async_query(
 
     # Snowflake docs incorrectly state that the 'params' argument is named 'parameters'.
     result = cursor.execute_async(query, params=parameters, file_stream=file_stream)
-    query_id = result["queryId"]
+    query_id = cursor.sfqid or result["queryId"]
 
     query_status = None
     try:
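
The sfqid fallback above is what lets FakeSnowflakeConnection drive execute_async_query: the real connector returns a payload containing "queryId" and also sets sfqid on the cursor, while the fake's execute_async returns None and only exposes the incrementing sfqid property, so the old result["queryId"] lookup would crash under the fake. A minimal sketch of the behaviour, assuming the fake classes from the test diff are in scope (resolve_query_id is a hypothetical helper for illustration):

    def resolve_query_id(cursor, result):
        # Prefer the query id set on the cursor; fall back to the response payload.
        return cursor.sfqid or result["queryId"]

    fake = FakeSnowflakeCursor()
    assert fake.execute_async('COPY INTO "events"') is None  # the fake only records the call
    assert resolve_query_id(fake, {}) == 1  # fake query ids start at 1 and increment per access
    assert resolve_query_id(fake, {}) == 2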