diff --git a/.github/actions/run-backend-tests/action.yml b/.github/actions/run-backend-tests/action.yml index edd36992a6614..6471740a9d6c6 100644 --- a/.github/actions/run-backend-tests/action.yml +++ b/.github/actions/run-backend-tests/action.yml @@ -83,7 +83,7 @@ runs: - uses: syphar/restore-virtualenv@v1 id: cache-backend-tests with: - custom_cache_key_element: v1 + custom_cache_key_element: v2 - uses: syphar/restore-pip-download-cache@v1 if: steps.cache-backend-tests.outputs.cache-hit != 'true' diff --git a/.github/workflows/ci-backend.yml b/.github/workflows/ci-backend.yml index f80300fca0a56..4e09f52eda7b6 100644 --- a/.github/workflows/ci-backend.yml +++ b/.github/workflows/ci-backend.yml @@ -109,7 +109,7 @@ jobs: - uses: syphar/restore-virtualenv@v1 id: cache-backend-tests with: - custom_cache_key_element: v1- + custom_cache_key_element: v2- - uses: syphar/restore-pip-download-cache@v1 if: steps.cache-backend-tests.outputs.cache-hit != 'true' @@ -331,7 +331,7 @@ jobs: - uses: syphar/restore-virtualenv@v1 id: cache-backend-tests with: - custom_cache_key_element: v1- + custom_cache_key_element: v2- - uses: syphar/restore-pip-download-cache@v1 if: steps.cache-backend-tests.outputs.cache-hit != 'true' diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py index 28ef7aee14c6c..98250309e33cc 100644 --- a/posthog/temporal/tests/batch_exports/conftest.py +++ b/posthog/temporal/tests/batch_exports/conftest.py @@ -1,5 +1,7 @@ +import psycopg import pytest import pytest_asyncio +from psycopg import sql @pytest.fixture @@ -39,4 +41,72 @@ async def truncate_events(clickhouse_client): This is useful if during the test setup we insert a lot of events we wish to clean-up. """ yield - await clickhouse_client.execute_query("TRUNCATE TABLE `sharded_events`") + await clickhouse_client.execute_query("TRUNCATE TABLE IF EXISTS `sharded_events`") + + +@pytest_asyncio.fixture +async def setup_postgres_test_db(postgres_config): + """Fixture to manage a database for Redshift export testing. + + Managing a test database involves the following steps: + 1. Creating a test database. + 2. Initializing a connection to that database. + 3. Creating a test schema. + 4. Yielding the connection to be used in tests. + 5. After tests, drop the test schema and any tables in it. + 6. Drop the test database. + """ + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), + (postgres_config["database"],), + ) + + if await cursor.fetchone() is None: + await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() + + # We need a new connection to connect to the database we just created. 
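+    # This time we pass 'dbname', so the test schema below is created inside the test database we just created.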
+ connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + dbname=postgres_config["database"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])) + ) + + yield + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) + + await connection.close() + + # We need a new connection to drop the database, as we cannot drop the current database. + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index 3b988307e2e91..097f13869d2eb 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -1,38 +1,25 @@ import csv -import dataclasses import datetime as dt import io import json -import logging import operator -import random -import string -import uuid from random import randint -from unittest.mock import patch import pytest -from freezegun import freeze_time -from temporalio import activity, workflow -from posthog.clickhouse.log_entries import ( - KAFKA_LOG_ENTRIES, -) from posthog.temporal.tests.utils.datetimes import ( to_isoformat, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse from posthog.temporal.workflows.batch_exports import ( BatchExportTemporaryFile, - KafkaLoggingHandler, - get_batch_exports_logger, get_data_interval, get_results_iterator, get_rows_count, json_dumps_bytes, ) -pytestmark = [pytest.mark.django_db, pytest.mark.asyncio] +pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] async def test_get_rows_count(clickhouse_client): @@ -540,104 +527,3 @@ def test_batch_export_temporary_file_write_records_to_tsv(records): assert be_file.bytes_since_last_reset == 0 assert be_file.records_total == len(records) assert be_file.records_since_last_reset == 0 - - -def test_kafka_logging_handler_produces_to_kafka(caplog): - """Test a mocked call to Kafka produce from the KafkaLoggingHandler.""" - logger_name = "test-logger" - logger = logging.getLogger(logger_name) - handler = KafkaLoggingHandler(topic=KAFKA_LOG_ENTRIES) - handler.setLevel(logging.DEBUG) - logger.addHandler(handler) - - team_id = random.randint(1, 10000) - batch_export_id = str(uuid.uuid4()) - run_id = str(uuid.uuid4()) - timestamp = "2023-09-21 00:01:01.000001" - - expected_tuples = [] - expected_kafka_produce_calls_kwargs = [] - - with patch("posthog.kafka_client.client._KafkaProducer.produce") as produce: - with caplog.at_level(logging.DEBUG): - with freeze_time(timestamp): - for level in (10, 20, 30, 40, 50): - random_message = "".join(random.choice(string.ascii_letters) for _ in range(30)) - - logger.log( - level, - random_message, - extra={ - "team_id": team_id, - "batch_export_id": batch_export_id, - "workflow_run_id": run_id, - }, - ) 
- - expected_tuples.append( - ( - logger_name, - level, - random_message, - ) - ) - data = { - "message": random_message, - "team_id": team_id, - "log_source": "batch_exports", - "log_source_id": batch_export_id, - "instance_id": run_id, - "timestamp": timestamp, - "level": logging.getLevelName(level), - } - expected_kafka_produce_calls_kwargs.append({"topic": KAFKA_LOG_ENTRIES, "data": data, "key": None}) - - assert caplog.record_tuples == expected_tuples - - kafka_produce_calls_kwargs = [call.kwargs for call in produce.call_args_list] - assert kafka_produce_calls_kwargs == expected_kafka_produce_calls_kwargs - - -@dataclasses.dataclass -class TestInputs: - team_id: int - data_interval_end: str | None = None - interval: str = "hour" - batch_export_id: str = "" - - -@dataclasses.dataclass -class TestInfo: - workflow_id: str - run_id: str - workflow_run_id: str - attempt: int - - -@pytest.mark.parametrize("context", [activity.__name__, workflow.__name__]) -def test_batch_export_logger_adapter(context, caplog): - """Test BatchExportLoggerAdapter sets the appropiate context variables.""" - team_id = random.randint(1, 10000) - inputs = TestInputs(team_id=team_id) - logger = get_batch_exports_logger(inputs=inputs) - - batch_export_id = str(uuid.uuid4()) - run_id = str(uuid.uuid4()) - attempt = random.randint(1, 10) - info = TestInfo( - workflow_id=f"{batch_export_id}-{dt.datetime.utcnow().isoformat()}", - run_id=run_id, - workflow_run_id=run_id, - attempt=attempt, - ) - - with patch("posthog.kafka_client.client._KafkaProducer.produce"): - with patch(context + ".info", return_value=info): - for level in (10, 20, 30, 40, 50): - logger.log(level, "test") - - records = caplog.get_records("call") - assert all(record.team_id == team_id for record in records) - assert all(record.batch_export_id == batch_export_id for record in records) - assert all(record.workflow_run_id == run_id for record in records) - assert all(record.attempt == attempt for record in records) diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py index 72646e2e993c4..d6696407c60e2 100644 --- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py @@ -30,7 +30,13 @@ insert_into_bigquery_activity, ) -pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] +SKIP_IF_MISSING_GOOGLE_APPLICATION_CREDENTIALS = pytest.mark.skipif( + "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ, + reason="Google credentials not set in environment", +) + +pytestmark = [SKIP_IF_MISSING_GOOGLE_APPLICATION_CREDENTIALS, pytest.mark.asyncio, pytest.mark.django_db] + TEST_TIME = dt.datetime.utcnow() @@ -108,9 +114,6 @@ def bigquery_config() -> dict[str, str]: "private_key_id": credentials["private_key_id"], "token_uri": credentials["token_uri"], "client_email": credentials["client_email"], - # Not part of the credentials. - # Hardcoded to test dataset. - "dataset_id": "BatchExports", } @@ -119,19 +122,30 @@ def bigquery_client() -> typing.Generator[bigquery.Client, None, None]: """Manage a bigquery.Client for testing.""" client = bigquery.Client() - try: - yield client - finally: - client.close() + yield client + + client.close() + + +@pytest.fixture +def bigquery_dataset(bigquery_config, bigquery_client) -> typing.Generator[bigquery.Dataset, None, None]: + """Manage a bigquery dataset for testing. + + We clean up the dataset after every test. 
Could be quite time expensive, but guarantees a clean slate. + """ + dataset_id = f"{bigquery_config['project_id']}.BatchExportsTest_{str(uuid4()).replace('-', '')}" + + dataset = bigquery.Dataset(dataset_id) + dataset = bigquery_client.create_dataset(dataset) + + yield dataset + + bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) -@pytest.mark.skipif( - "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ, - reason="Google credentials not set in environment", -) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table( - clickhouse_client, activity_environment, bigquery_client, bigquery_config, exclude_events + clickhouse_client, activity_environment, bigquery_client, bigquery_config, exclude_events, bigquery_dataset ): """Test that the insert_into_bigquery_activity function inserts data into a BigQuery table. @@ -194,6 +208,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table( insert_inputs = BigQueryInsertInputs( team_id=team_id, table_id=f"test_insert_activity_table_{team_id}", + dataset_id=bigquery_dataset.dataset_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), exclude_events=exclude_events, @@ -208,7 +223,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table( assert_events_in_bigquery( client=bigquery_client, table_id=f"test_insert_activity_table_{team_id}", - dataset_id=bigquery_config["dataset_id"], + dataset_id=bigquery_dataset.dataset_id, events=events + events_with_no_properties, bq_ingested_timestamp=ingested_timestamp, exclude_events=exclude_events, @@ -221,12 +236,15 @@ def table_id(ateam, interval): @pytest_asyncio.fixture -async def bigquery_batch_export(ateam, table_id, bigquery_config, interval, exclude_events, temporal_client): +async def bigquery_batch_export( + ateam, table_id, bigquery_config, interval, exclude_events, temporal_client, bigquery_dataset +): destination_data = { "type": "BigQuery", "config": { **bigquery_config, "table_id": table_id, + "dataset_id": bigquery_dataset.dataset_id, "exclude_events": exclude_events, }, } @@ -249,15 +267,10 @@ async def bigquery_batch_export(ateam, table_id, bigquery_config, interval, excl await adelete_batch_export(batch_export, temporal_client) -@pytest.mark.skipif( - "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ, - reason="Google credentials not set in environment", -) @pytest.mark.parametrize("interval", ["hour", "day"]) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) async def test_bigquery_export_workflow( clickhouse_client, - bigquery_config, bigquery_client, bigquery_batch_export, interval, @@ -303,7 +316,7 @@ async def test_bigquery_export_workflow( inputs = BigQueryBatchExportInputs( team_id=ateam.pk, batch_export_id=str(bigquery_batch_export.id), - data_interval_end="2023-04-25 14:30:00.000000", + data_interval_end=data_interval_end.isoformat(), interval=interval, **bigquery_batch_export.destination.config, ) @@ -340,17 +353,13 @@ async def test_bigquery_export_workflow( assert_events_in_bigquery( client=bigquery_client, table_id=table_id, - dataset_id=bigquery_config["dataset_id"], + dataset_id=bigquery_batch_export.destination.config["dataset_id"], events=events, bq_ingested_timestamp=ingested_timestamp, exclude_events=exclude_events, ) -@pytest.mark.skipif( - "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ, - 
reason="Google credentials not set in environment", -) async def test_bigquery_export_workflow_handles_insert_activity_errors(ateam, bigquery_batch_export, interval): """Test that BigQuery Export Workflow can gracefully handle errors when inserting BigQuery data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") @@ -397,10 +406,6 @@ async def insert_into_bigquery_activity_mocked(_: BigQueryInsertInputs) -> str: assert run.latest_error == "ValueError: A useful error message" -@pytest.mark.skipif( - "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ, - reason="Google credentials not set in environment", -) async def test_bigquery_export_workflow_handles_cancellation(ateam, bigquery_batch_export, interval): """Test that BigQuery Export Workflow can gracefully handle cancellations when inserting BigQuery data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") @@ -439,6 +444,7 @@ async def never_finish_activity(_: BigQueryInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, retry_policy=RetryPolicy(maximum_attempts=1), ) + await asyncio.sleep(5) await handle.cancel() diff --git a/posthog/temporal/tests/batch_exports/test_logger.py b/posthog/temporal/tests/batch_exports/test_logger.py new file mode 100644 index 0000000000000..7614ab721f340 --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_logger.py @@ -0,0 +1,408 @@ +import asyncio +import dataclasses +import datetime as dt +import json +import random +import time +import uuid + +import aiokafka +import freezegun +import pytest +import pytest_asyncio +import structlog +import temporalio.activity +import temporalio.testing +from django.conf import settings + +from posthog.clickhouse.client import sync_execute +from posthog.clickhouse.log_entries import ( + KAFKA_LOG_ENTRIES_TABLE_SQL, + LOG_ENTRIES_TABLE, + LOG_ENTRIES_TABLE_MV_SQL, + TRUNCATE_LOG_ENTRIES_TABLE_SQL, +) +from posthog.kafka_client.topics import KAFKA_LOG_ENTRIES +from posthog.temporal.workflows.logger import bind_batch_exports_logger, configure_logger + +pytestmark = pytest.mark.asyncio + + +class LogCapture: + """A test StructLog processor to capture logs.""" + + def __init__(self): + self.entries = [] + + def __call__(self, logger, method_name, event_dict): + """Append event_dict to entries and drop the log.""" + self.entries.append(event_dict) + raise structlog.DropEvent() + + +@pytest.fixture() +def log_capture(): + """Return a LogCapture processor for inspection in tests.""" + return LogCapture() + + +class QueueCapture(asyncio.Queue): + """A test asyncio.Queue that captures items that we put into it.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.entries = [] + + def put_nowait(self, item): + """Append item to entries and delegate to asyncio.Queue.""" + self.entries.append(item) + super().put_nowait(item) + + +@pytest_asyncio.fixture() +async def queue(): + """Return a QueueCapture queue for inspection in tests.""" + queue = QueueCapture(maxsize=-1) + + yield queue + + +class CaptureKafkaProducer: + """A test aiokafka.AIOKafkaProducer that captures calls to send_and_wait.""" + + def __init__(self, *args, **kwargs): + self.entries = [] + self._producer: None | aiokafka.AIOKafkaProducer = None + + @property + def producer(self) -> aiokafka.AIOKafkaProducer: + if self._producer is None: + self._producer = aiokafka.AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_HOSTS + ["localhost:9092"], + security_protocol=settings.KAFKA_SECURITY_PROTOCOL or 
"PLAINTEXT", + acks="all", + request_timeout_ms=1000000, + api_version="2.5.0", + ) + return self._producer + + async def send(self, topic, value=None, key=None, partition=None, timestamp_ms=None, headers=None): + """Append an entry and delegate to aiokafka.AIOKafkaProducer.""" + + self.entries.append( + { + "topic": topic, + "value": value, + "key": key, + "partition": partition, + "timestamp_ms": timestamp_ms, + "headers": headers, + } + ) + return await self.producer.send(topic, value, key, partition, timestamp_ms, headers) + + async def start(self): + await self.producer.start() + + async def stop(self): + await self.producer.stop() + + async def flush(self): + await self.producer.flush() + + @property + def _closed(self): + return self.producer._closed + + +@pytest_asyncio.fixture(scope="function") +async def producer(event_loop): + """Yield a CaptureKafkaProducer to inspect entries captured. + + After usage, we ensure the producer was closed to avoid leaking/warnings. + """ + producer = CaptureKafkaProducer(bootstrap_servers=settings.KAFKA_HOSTS, loop=event_loop) + + yield producer + + if producer._closed is False: + await producer.stop() + + +@pytest_asyncio.fixture(autouse=True) +async def configure(log_capture, queue, producer): + """Configure StructLog logging for testing. + + The extra parameters configured for testing are: + * Add a LogCapture processor to capture logs. + * Set the queue and producer to capture messages sent. + * Do not cache logger to ensure each test starts clean. + """ + tasks = await configure_logger( + extra_processors=[log_capture], queue=queue, producer=producer, cache_logger_on_first_use=False + ) + yield tasks + + for task in tasks: + # Clean up logger tasks to avoid leaking/warnings. + task.cancel() + + await asyncio.wait(tasks) + + +async def test_batch_exports_logger_binds_context(log_capture): + """Test whether we can bind context variables.""" + logger = await bind_batch_exports_logger(team_id=1, destination="Somewhere") + + logger.info("Hi! This is an info log") + logger.error("Hi! This is an erro log") + + assert len(log_capture.entries) == 2 + + info_entry, error_entry = log_capture.entries + info_dict, error_dict = json.loads(info_entry), json.loads(error_entry) + assert info_dict["team_id"] == 1 + assert info_dict["destination"] == "Somewhere" + + assert error_dict["team_id"] == 1 + assert error_dict["destination"] == "Somewhere" + + +async def test_batch_exports_logger_formats_positional_args(log_capture): + """Test whether positional arguments are formatted in the message.""" + logger = await bind_batch_exports_logger(team_id=1, destination="Somewhere") + + logger.info("Hi! This is an %s log", "info") + logger.error("Hi! This is an %s log", "error") + + assert len(log_capture.entries) == 2 + + info_entry, error_entry = log_capture.entries + info_dict, error_dict = json.loads(info_entry), json.loads(error_entry) + assert info_dict["msg"] == "Hi! This is an info log" + assert error_dict["msg"] == "Hi! 
This is an error log" + + +@dataclasses.dataclass +class ActivityInfo: + """Provide our own Activity Info for testing.""" + + workflow_id: str + workflow_type: str + workflow_run_id: str + attempt: int + + +@pytest.fixture +def activity_environment(request): + """Return a testing temporal ActivityEnvironment.""" + env = temporalio.testing.ActivityEnvironment() + env.info = request.param + return env + + +BATCH_EXPORT_ID = str(uuid.uuid4()) + + +@pytest.mark.parametrize( + "activity_environment", + [ + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-{dt.datetime.utcnow()}", + workflow_type="s3-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-Backfill-{dt.datetime.utcnow()}", + workflow_type="backfill-batch-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ], + indirect=True, +) +async def test_batch_exports_logger_binds_activity_context( + log_capture, + activity_environment, +): + """Test whether our logger binds variables from a Temporal Activity.""" + + @temporalio.activity.defn + async def log_activity(): + """A simple temporal activity that just logs.""" + logger = await bind_batch_exports_logger(team_id=1, destination="Somewhere") + + logger.info("Hi! This is an %s log from an activity", "info") + + await activity_environment.run(log_activity) + + assert len(log_capture.entries) == 1 + + info_dict = json.loads(log_capture.entries[0]) + assert info_dict["team_id"] == 1 + assert info_dict["destination"] == "Somewhere" + assert info_dict["workflow_id"] == activity_environment.info.workflow_id + assert info_dict["workflow_type"] == activity_environment.info.workflow_type + assert info_dict["log_source_id"] == BATCH_EXPORT_ID + assert info_dict["workflow_run_id"] == activity_environment.info.workflow_run_id + assert info_dict["attempt"] == activity_environment.info.attempt + + if activity_environment.info.workflow_type == "backfill-batch-export": + assert info_dict["log_source"] == "batch_exports_backfill" + else: + assert info_dict["log_source"] == "batch_exports" + + +@freezegun.freeze_time("2023-11-02 10:00:00.123123") +@pytest.mark.parametrize( + "activity_environment", + [ + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-{dt.datetime.utcnow()}", + workflow_type="s3-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-Backfill-{dt.datetime.utcnow()}", + workflow_type="backfill-batch-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ], + indirect=True, +) +async def test_batch_exports_logger_puts_in_queue(activity_environment, queue): + """Test whether our logger puts entries into a queue for async processing.""" + + @temporalio.activity.defn + async def log_activity(): + """A simple temporal activity that just logs.""" + logger = await bind_batch_exports_logger(team_id=2, destination="Somewhere") + + logger.info("Hi! 
This is an %s log from an activity", "info") + + await activity_environment.run(log_activity) + + assert len(queue.entries) == 1 + message_dict = json.loads(queue.entries[0].decode("utf-8")) + + assert message_dict["instance_id"] == activity_environment.info.workflow_run_id + assert message_dict["level"] == "info" + + if activity_environment.info.workflow_type == "backfill-batch-export": + assert message_dict["log_source"] == "batch_exports_backfill" + else: + assert message_dict["log_source"] == "batch_exports" + + assert message_dict["log_source_id"] == BATCH_EXPORT_ID + assert message_dict["message"] == "Hi! This is an info log from an activity" + assert message_dict["team_id"] == 2 + assert message_dict["timestamp"] == "2023-11-02 10:00:00.123123" + + +@pytest.fixture +def log_entries_table(): + """Manage log_entries table for testing.""" + sync_execute(KAFKA_LOG_ENTRIES_TABLE_SQL()) + sync_execute(LOG_ENTRIES_TABLE_MV_SQL) + sync_execute(TRUNCATE_LOG_ENTRIES_TABLE_SQL) + + yield LOG_ENTRIES_TABLE + + sync_execute(f"DROP TABLE {LOG_ENTRIES_TABLE}_mv") + sync_execute(f"DROP TABLE kafka_{LOG_ENTRIES_TABLE}") + sync_execute(TRUNCATE_LOG_ENTRIES_TABLE_SQL) + + +@pytest.mark.django_db +@pytest.mark.parametrize( + "activity_environment", + [ + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-{dt.datetime.utcnow()}", + workflow_type="s3-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ActivityInfo( + workflow_id=f"{BATCH_EXPORT_ID}-Backfill-{dt.datetime.utcnow()}", + workflow_type="backfill-batch-export", + workflow_run_id=str(uuid.uuid4()), + attempt=random.randint(1, 10000), + ), + ], + indirect=True, +) +async def test_batch_exports_logger_produces_to_kafka(activity_environment, producer, queue, log_entries_table): + """Test whether our logger produces messages to Kafka. + + We also check if those messages are ingested into ClickHouse. + """ + + @temporalio.activity.defn + async def log_activity(): + """A simple temporal activity that just logs.""" + logger = await bind_batch_exports_logger(team_id=3, destination="Somewhere") + + logger.info("Hi! This is an %s log from an activity", "info") + + with freezegun.freeze_time("2023-11-03 10:00:00.123123"): + await activity_environment.run(log_activity) + + assert len(queue.entries) == 1 + + await queue.join() + + if activity_environment.info.workflow_type == "backfill-batch-export": + expected_log_source = "batch_exports_backfill" + else: + expected_log_source = "batch_exports" + + expected_dict = { + "instance_id": activity_environment.info.workflow_run_id, + "level": "info", + "log_source": expected_log_source, + "log_source_id": BATCH_EXPORT_ID, + "message": "Hi! This is an info log from an activity", + "team_id": 3, + "timestamp": "2023-11-03 10:00:00.123123", + } + + assert len(producer.entries) == 1 + assert producer.entries[0] == { + "topic": KAFKA_LOG_ENTRIES, + "value": json.dumps(expected_dict).encode("utf-8"), + "key": None, + "partition": None, + "timestamp_ms": None, + "headers": None, + } + + results = sync_execute( + f"SELECT instance_id, level, log_source, log_source_id, message, team_id, timestamp FROM {log_entries_table}" + ) + + iterations = 0 + while not results: + # It may take a bit for CH to ingest. 
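+        # Poll roughly once per second, giving up after ~10 attempts.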
+ time.sleep(1) + results = sync_execute( + f"SELECT instance_id, level, log_source, log_source_id, message, team_id, timestamp FROM {log_entries_table}" + ) + + iterations += 1 + if iterations > 10: + raise TimeoutError("Timedout waiting for logs") + + assert len(results) == 1 # type: ignore + + row = results[0] # type: ignore + assert row[0] == activity_environment.info.workflow_run_id + assert row[1] == "info" + assert row[2] == expected_log_source + assert row[3] == BATCH_EXPORT_ID + assert row[4] == "Hi! This is an info log from an activity" + assert row[5] == 3 + assert row[6].isoformat() == "2023-11-03T10:00:00.123123+00:00" diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index 1e0643f4f1b9b..6a70e9f2eb74c 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -4,12 +4,12 @@ from random import randint from uuid import uuid4 -import psycopg2 +import psycopg import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings -from psycopg2 import sql +from psycopg import sql from temporalio import activity from temporalio.client import WorkflowFailureError from temporalio.common import RetryPolicy @@ -29,18 +29,23 @@ insert_into_postgres_activity, ) -pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.django_db, +] -def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None): +async def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None): """Assert provided events written to a given Postgres table.""" inserted_events = [] - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name))) + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name)) + ) columns = [column.name for column in cursor.description] - for row in cursor.fetchall(): + for row in await cursor.fetchall(): event = dict(zip(columns, row)) event["timestamp"] = dt.datetime.fromisoformat(event["timestamp"].isoformat()) inserted_events.append(event) @@ -58,12 +63,12 @@ def assert_events_in_postgres(connection, schema, table_name, events, exclude_ev "distinct_id": event.get("distinct_id"), "elements": json.dumps(elements_chain), "event": event.get("event"), - "ip": properties.get("$ip", None) if properties else None, + "ip": properties.get("$ip", "") if properties else "", "properties": event.get("properties"), "set": properties.get("$set", None) if properties else None, "set_once": properties.get("$set_once", None) if properties else None, # Kept for backwards compatibility, but not exported anymore. - "site_url": None, + "site_url": "", # For compatibility with CH which doesn't parse timezone component, so we add it here assuming UTC. 
"timestamp": dt.datetime.fromisoformat(event.get("timestamp") + "+00:00"), "team_id": event.get("team_id"), @@ -91,75 +96,19 @@ def postgres_config(): } -@pytest.fixture -def setup_test_db(postgres_config): - connection = psycopg2.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - ) - connection.set_session(autocommit=True) - - with connection.cursor() as cursor: - cursor.execute( - sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), - (postgres_config["database"],), - ) - - if cursor.fetchone() is None: - cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) - - connection.close() - - # We need a new connection to connect to the database we just created. - connection = psycopg2.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - database=postgres_config["database"], - ) - connection.set_session(autocommit=True) - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"]))) - - yield - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) - - connection.close() - - # We need a new connection to drop the database, as we cannot drop the current database. - connection = psycopg2.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - ) - connection.set_session(autocommit=True) - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) - - connection.close() - - -@pytest.fixture -def postgres_connection(postgres_config, setup_test_db): - connection = psycopg2.connect( +@pytest_asyncio.fixture +async def postgres_connection(postgres_config, setup_postgres_test_db): + connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], - database=postgres_config["database"], + dbname=postgres_config["database"], host=postgres_config["host"], port=postgres_config["port"], ) yield connection - connection.close() + await connection.close() @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @@ -238,7 +187,7 @@ async def test_insert_into_postgres_activity_inserts_data_into_postgres_table( with override_settings(BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2): await activity_environment.run(insert_into_postgres_activity, insert_inputs) - assert_events_in_postgres( + await assert_events_in_postgres( connection=postgres_connection, schema=postgres_config["schema"], table_name="test_table", @@ -326,7 +275,7 @@ async def test_postgres_export_workflow( inputs = PostgresBatchExportInputs( team_id=ateam.pk, batch_export_id=str(postgres_batch_export.id), - data_interval_end="2023-04-25 14:30:00.000000", + data_interval_end=data_interval_end.isoformat(), interval=interval, **postgres_batch_export.destination.config, ) @@ -359,7 +308,7 @@ async def test_postgres_export_workflow( run = runs[0] assert run.status == "Completed" - assert_events_in_postgres( + await assert_events_in_postgres( postgres_connection, postgres_config["schema"], table_name, diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py 
b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py index 2888484077371..176b487ff94a0 100644 --- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py @@ -1,15 +1,16 @@ import datetime as dt import json import os +import warnings from random import randint from uuid import uuid4 -import psycopg2 +import psycopg import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings -from psycopg2 import sql +from psycopg import sql from temporalio.common import RetryPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -33,24 +34,24 @@ "REDSHIFT_HOST", ) -SKIP_IF_MISSING_REQUIRED_ENV_VARS = pytest.mark.skipif( - any(env_var not in os.environ for env_var in REQUIRED_ENV_VARS), - reason="Redshift required env vars are not set", -) +MISSING_REQUIRED_ENV_VARS = any(env_var not in os.environ for env_var in REQUIRED_ENV_VARS) + -pytestmark = [SKIP_IF_MISSING_REQUIRED_ENV_VARS, pytest.mark.django_db, pytest.mark.asyncio] +pytestmark = [pytest.mark.django_db, pytest.mark.asyncio] -def assert_events_in_redshift(connection, schema, table_name, events, exclude_events: list[str] | None = None): +async def assert_events_in_redshift(connection, schema, table_name, events, exclude_events: list[str] | None = None): """Assert provided events written to a given Redshift table.""" inserted_events = [] - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT * FROM {} ORDER BY timestamp").format(sql.Identifier(schema, table_name))) + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name)) + ) columns = [column.name for column in cursor.description] - for row in cursor.fetchall(): + for row in await cursor.fetchall(): event = dict(zip(columns, row)) event["timestamp"] = dt.datetime.fromisoformat(event["timestamp"].isoformat()) inserted_events.append(event) @@ -81,7 +82,7 @@ def assert_events_in_redshift(connection, schema, table_name, events, exclude_ev } expected_events.append(expected_event) - expected_events.sort(key=lambda x: x["timestamp"]) + expected_events.sort(key=lambda x: (x["event"], x["timestamp"])) assert len(inserted_events) == len(expected_events) # First check one event, the first one, so that we can get a nice diff if @@ -94,17 +95,26 @@ def assert_events_in_redshift(connection, schema, table_name, events, exclude_ev def redshift_config(): """Fixture to provide a default configuration for Redshift batch exports. - Reads required env vars to construct configuration. + Reads required env vars to construct configuration, but if not present + we default to local development PostgreSQL database, which should be mostly compatible. """ - user = os.environ["REDSHIFT_USER"] - password = os.environ["REDSHIFT_PASSWORD"] - host = os.environ["REDSHIFT_HOST"] - port = os.environ.get("REDSHIFT_PORT", "5439") + if MISSING_REQUIRED_ENV_VARS: + user = settings.PG_USER + password = settings.PG_PASSWORD + host = settings.PG_HOST + port = int(settings.PG_PORT) + warnings.warn("Missing required Redshift env vars. 
Running tests against local PG database.", stacklevel=1) + + else: + user = os.environ["REDSHIFT_USER"] + password = os.environ["REDSHIFT_PASSWORD"] + host = os.environ["REDSHIFT_HOST"] + port = os.environ.get("REDSHIFT_PORT", "5439") return { "user": user, "password": password, - "database": "exports_test_database", + "database": "dev", "schema": "exports_test_schema", "host": host, "port": int(port), @@ -112,89 +122,30 @@ def redshift_config(): @pytest.fixture -def setup_test_db(redshift_config): - """Fixture to manage a database for Redshift export testing. - - Managing a test database involves the following steps: - 1. Creating a test database. - 2. Initializing a connection to that database. - 3. Creating a test schema. - 4. Yielding the connection to be used in tests. - 5. After tests, drop the test schema and any tables in it. - 6. Drop the test database. - """ - connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database="dev", - ) - connection.set_session(autocommit=True) +def postgres_config(redshift_config): + """We shadow this name so that setup_postgres_test_db works with Redshift.""" + return redshift_config - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), (redshift_config["database"],)) - - if cursor.fetchone() is None: - cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(redshift_config["database"]))) - - connection.close() - - # We need a new connection to connect to the database we just created. - connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database=redshift_config["database"], - ) - connection.set_session(autocommit=True) - with connection.cursor() as cursor: - cursor.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(redshift_config["schema"]))) - - yield - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(redshift_config["schema"]))) - - connection.close() - - # We need a new connection to drop the database, as we cannot drop the current database. 
- connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database="dev", - ) - connection.set_session(autocommit=True) - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(redshift_config["database"]))) - - connection.close() - - -@pytest.fixture -def psycopg2_connection(redshift_config, setup_test_db): +@pytest_asyncio.fixture +async def psycopg_connection(redshift_config, setup_postgres_test_db): """Fixture to manage a psycopg2 connection.""" - connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=redshift_config["user"], password=redshift_config["password"], - database=redshift_config["database"], + dbname=redshift_config["database"], host=redshift_config["host"], port=redshift_config["port"], ) yield connection - connection.close() + await connection.close() @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( - clickhouse_client, activity_environment, psycopg2_connection, redshift_config, exclude_events + clickhouse_client, activity_environment, psycopg_connection, redshift_config, exclude_events ): """Test that the insert_into_redshift_activity function inserts data into a Redshift table. @@ -265,8 +216,8 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( await activity_environment.run(insert_into_redshift_activity, insert_inputs) - assert_events_in_redshift( - connection=psycopg2_connection, + await assert_events_in_redshift( + connection=psycopg_connection, schema=redshift_config["schema"], table_name="test_table", events=events + events_with_no_properties, @@ -308,11 +259,12 @@ async def redshift_batch_export(ateam, table_name, redshift_config, interval, ex async def test_redshift_export_workflow( clickhouse_client, redshift_config, - psycopg2_connection, + psycopg_connection, interval, redshift_batch_export, ateam, exclude_events, + table_name, ): """Test Redshift Export Workflow end-to-end. 
@@ -385,8 +337,8 @@ async def test_redshift_export_workflow( run = runs[0] assert run.status == "Completed" - assert_events_in_redshift( - psycopg2_connection, + await assert_events_in_redshift( + psycopg_connection, redshift_config["schema"], table_name, events=events, diff --git a/posthog/temporal/tests/conftest.py b/posthog/temporal/tests/conftest.py index 7c756ed44f717..4c480989db92b 100644 --- a/posthog/temporal/tests/conftest.py +++ b/posthog/temporal/tests/conftest.py @@ -63,7 +63,7 @@ def activity_environment(): return ActivityEnvironment() -@pytest.fixture(scope="module") +@pytest.fixture def clickhouse_client(): """Provide a ClickHouseClient to use in tests.""" client = ClickHouseClient( @@ -76,14 +76,7 @@ def clickhouse_client(): yield client -@pytest.fixture(scope="module") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.close() - - -@pytest_asyncio.fixture(scope="module") +@pytest_asyncio.fixture async def temporal_client(): """Provide a temporalio.client.Client to use in tests.""" client = await connect( diff --git a/posthog/temporal/workflows/backfill_batch_export.py b/posthog/temporal/workflows/backfill_batch_export.py index 17f55ae1d8b54..724a745451d4f 100644 --- a/posthog/temporal/workflows/backfill_batch_export.py +++ b/posthog/temporal/workflows/backfill_batch_export.py @@ -20,9 +20,9 @@ CreateBatchExportBackfillInputs, UpdateBatchExportBackfillStatusInputs, create_batch_export_backfill_model, - get_batch_exports_logger, update_batch_export_backfill_model_status, ) +from posthog.temporal.workflows.logger import bind_batch_exports_logger class HeartbeatDetails(typing.NamedTuple): @@ -284,10 +284,9 @@ def parse_inputs(inputs: list[str]) -> BackfillBatchExportInputs: @temporalio.workflow.run async def run(self, inputs: BackfillBatchExportInputs) -> None: """Workflow implementation to backfill a BatchExport.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id) logger.info( - "Starting Backfill for BatchExport %s: %s - %s", - inputs.batch_export_id, + "Starting Backfill for BatchExport: %s - %s", inputs.start_at, inputs.end_at, ) diff --git a/posthog/temporal/workflows/batch_exports.py b/posthog/temporal/workflows/batch_exports.py index 063069388801f..e35b91191bffa 100644 --- a/posthog/temporal/workflows/batch_exports.py +++ b/posthog/temporal/workflows/batch_exports.py @@ -4,9 +4,6 @@ import datetime as dt import gzip import json -import logging -import logging.handlers -import queue import tempfile import typing import uuid @@ -18,14 +15,12 @@ from temporalio.common import RetryPolicy from posthog.batch_exports.service import ( - BatchExportsInputsProtocol, create_batch_export_backfill, create_batch_export_run, update_batch_export_backfill_status, update_batch_export_run_status, ) -from posthog.kafka_client.client import KafkaProducer -from posthog.kafka_client.topics import KAFKA_LOG_ENTRIES +from posthog.temporal.workflows.logger import bind_batch_exports_logger SELECT_QUERY_TEMPLATE = Template( """ @@ -465,136 +460,6 @@ def reset(self): self.records_since_last_reset = 0 -class BatchExportLoggerAdapter(logging.LoggerAdapter): - """Adapter that adds batch export details to log records.""" - - def __init__( - self, - logger: logging.Logger, - extra=None, - ) -> None: - """Create the logger adapter.""" - super().__init__(logger, extra or {}) - - def process(self, msg: str, kwargs) -> tuple[typing.Any, collections.abc.MutableMapping[str, typing.Any]]: - """Override to add batch 
exports details.""" - workflow_id = None - workflow_run_id = None - attempt = None - - try: - activity_info = activity.info() - except RuntimeError: - pass - else: - workflow_run_id = activity_info.workflow_run_id - workflow_id = activity_info.workflow_id - attempt = activity_info.attempt - - try: - workflow_info = workflow.info() - except RuntimeError: - pass - else: - workflow_run_id = workflow_info.run_id - workflow_id = workflow_info.workflow_id - attempt = workflow_info.attempt - - if workflow_id is None or workflow_run_id is None or attempt is None: - return (None, {}) - - # This works because the WorkflowID is made up like f"{batch_export_id}-{data_interval_end}" - # Since {data_interval_date} is an iso formatted datetime string, it has two '-' to separate the - # date. Plus one more leaves us at the end of {batch_export_id}. - batch_export_id = workflow_id.rsplit("-", maxsplit=3)[0] - - extra = kwargs.get("extra", None) or {} - extra["workflow_id"] = workflow_id - extra["batch_export_id"] = batch_export_id - extra["workflow_run_id"] = workflow_run_id - extra["attempt"] = attempt - - if isinstance(self.extra, dict): - extra = extra | self.extra - kwargs["extra"] = extra - - return (msg, kwargs) - - @property - def base_logger(self) -> logging.Logger: - """Underlying logger usable for actions such as adding handlers/formatters.""" - return self.logger - - -class BatchExportsLogRecord(logging.LogRecord): - team_id: int - batch_export_id: str - workflow_run_id: str - attempt: int - - -class KafkaLoggingHandler(logging.Handler): - def __init__(self, topic, key=None): - super().__init__() - self.producer = KafkaProducer() - self.topic = topic - self.key = key - - def emit(self, record): - if record.name == "kafka": - return - - # This is a lie, but as long as this handler is used together - # with BatchExportLoggerAdapter we should be fine. - # This is definitely cheaper than a bunch if checks for attributes. - record = typing.cast(BatchExportsLogRecord, record) - - msg = self.format(record) - data = { - "instance_id": record.workflow_run_id, - "level": record.levelname, - "log_source": "batch_exports", - "log_source_id": record.batch_export_id, - "message": msg, - "team_id": record.team_id, - "timestamp": dt.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"), - } - - try: - future = self.producer.produce(topic=self.topic, data=data, key=self.key) - future.get(timeout=1) - except Exception as e: - logging.exception("Failed to produce log to Kafka topic %s", self.topic, exc_info=e) - - def close(self): - self.producer.close() - logging.Handler.close(self) - - -LOG_QUEUE: queue.Queue = queue.Queue(-1) -QUEUE_HANDLER = logging.handlers.QueueHandler(LOG_QUEUE) -QUEUE_HANDLER.setLevel(logging.DEBUG) - -KAFKA_HANDLER = KafkaLoggingHandler(topic=KAFKA_LOG_ENTRIES) -KAFKA_HANDLER.setLevel(logging.DEBUG) -QUEUE_LISTENER = logging.handlers.QueueListener(LOG_QUEUE, KAFKA_HANDLER) - -logger = logging.getLogger(__name__) -logger.addHandler(QUEUE_HANDLER) -logger.setLevel(logging.DEBUG) - - -def get_batch_exports_logger(inputs: BatchExportsInputsProtocol) -> BatchExportLoggerAdapter: - """Return a logger for BatchExports.""" - # Need a type comment as _thread is private. - if QUEUE_LISTENER._thread is None: # type: ignore - QUEUE_LISTENER.start() - - adapter = BatchExportLoggerAdapter(logger, {"team_id": inputs.team_id}) - - return adapter - - @dataclasses.dataclass class CreateBatchExportRunInputs: """Inputs to the create_export_run activity. 
@@ -620,9 +485,6 @@ async def create_export_run(inputs: CreateBatchExportRunInputs) -> str: Intended to be used in all export workflows, usually at the start, to create a model instance to represent them in our database. """ - logger = get_batch_exports_logger(inputs=inputs) - logger.info(f"Creating BatchExportRun model instance in team {inputs.team_id}.") - # 'sync_to_async' type hints are fixed in asgiref>=3.4.1 # But one of our dependencies is pinned to asgiref==3.3.2. # Remove these comments once we upgrade. @@ -633,8 +495,6 @@ async def create_export_run(inputs: CreateBatchExportRunInputs) -> str: status=inputs.status, ) - logger.info(f"Created BatchExportRun {run.id} in team {inputs.team_id}.") - return str(run.id) @@ -673,9 +533,6 @@ async def create_batch_export_backfill_model(inputs: CreateBatchExportBackfillIn Intended to be used in all export workflows, usually at the start, to create a model instance to represent them in our database. """ - logger = get_batch_exports_logger(inputs=inputs) - logger.info(f"Creating BatchExportBackfill model instance in team {inputs.team_id}.") - # 'sync_to_async' type hints are fixed in asgiref>=3.4.1 # But one of our dependencies is pinned to asgiref==3.3.2. # Remove these comments once we upgrade. @@ -687,8 +544,6 @@ async def create_batch_export_backfill_model(inputs: CreateBatchExportBackfillIn team_id=inputs.team_id, ) - logger.info(f"Created BatchExportBackfill {run.id} in team {inputs.team_id}.") - return str(run.id) @@ -734,7 +589,7 @@ async def execute_batch_export_insert_activity( initial_retry_interval_seconds: When retrying, seconds until the first retry. maximum_retry_interval_seconds: Maximum interval in seconds between retries. """ - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id) retry_policy = RetryPolicy( initial_interval=dt.timedelta(seconds=initial_retry_interval_seconds), diff --git a/posthog/temporal/workflows/bigquery_batch_export.py b/posthog/temporal/workflows/bigquery_batch_export.py index d9557d31bb07b..a743d665bb15b 100644 --- a/posthog/temporal/workflows/bigquery_batch_export.py +++ b/posthog/temporal/workflows/bigquery_batch_export.py @@ -17,12 +17,12 @@ UpdateBatchExportRunStatusInputs, create_export_run, execute_batch_export_insert_activity, - get_batch_exports_logger, get_data_interval, get_results_iterator, get_rows_count, ) from posthog.temporal.workflows.clickhouse import get_client +from posthog.temporal.workflows.logger import bind_batch_exports_logger def load_jsonl_file_to_bigquery_table(jsonl_file, table, table_schema, bigquery_client): @@ -98,9 +98,9 @@ def bigquery_client(inputs: BigQueryInsertInputs): @activity.defn async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): """Activity streams data from ClickHouse to BigQuery.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="BigQuery") logger.info( - "Running BigQuery export batch %s - %s", + "Exporting batch %s - %s", inputs.data_interval_start, inputs.data_interval_end, ) @@ -126,7 +126,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): ) return - logger.info("BatchExporting %s rows to BigQuery", count) + logger.info("BatchExporting %s rows", count) results_iterator = get_results_iterator( client=client, @@ -174,7 +174,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): if jsonl_file.tell() > 
settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES: logger.info( - "Copying %s records of size %s bytes to BigQuery", + "Copying %s records of size %s bytes", jsonl_file.records_since_last_reset, jsonl_file.bytes_since_last_reset, ) @@ -188,7 +188,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): if jsonl_file.tell() > 0: logger.info( - "Copying %s records of size %s bytes to BigQuery", + "Copying %s records of size %s bytes", jsonl_file.records_since_last_reset, jsonl_file.bytes_since_last_reset, ) @@ -214,10 +214,10 @@ def parse_inputs(inputs: list[str]) -> BigQueryBatchExportInputs: @workflow.run async def run(self, inputs: BigQueryBatchExportInputs): """Workflow implementation to export data to BigQuery.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="BigQuery") data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) logger.info( - "Starting BigQuery export batch %s - %s", + "Starting batch export %s - %s", data_interval_start, data_interval_end, ) diff --git a/posthog/temporal/workflows/logger.py b/posthog/temporal/workflows/logger.py new file mode 100644 index 0000000000000..d70efbe9cff20 --- /dev/null +++ b/posthog/temporal/workflows/logger.py @@ -0,0 +1,272 @@ +import asyncio +import json +import logging + +import aiokafka +import structlog +import temporalio.activity +import temporalio.workflow +from django.conf import settings +from structlog.processors import EventRenamer +from structlog.typing import FilteringBoundLogger + +from posthog.kafka_client.topics import KAFKA_LOG_ENTRIES + + +async def bind_batch_exports_logger(team_id: int, destination: str | None = None) -> FilteringBoundLogger: + """Return a bound logger for BatchExports.""" + if not structlog.is_configured(): + await configure_logger() + + logger = structlog.get_logger() + + return logger.new(team_id=team_id, destination=destination) + + +async def configure_logger( + logger_factory=structlog.PrintLoggerFactory, + extra_processors: list[structlog.types.Processor] | None = None, + queue: asyncio.Queue | None = None, + producer: aiokafka.AIOKafkaProducer | None = None, + cache_logger_on_first_use: bool = True, +) -> tuple: + """Configure a StructLog logger for batch exports. + + Configuring the logger involves: + * Setting up processors. + * Spawning a task to listen for Kafka logs. + * Spawning a task to shutdown gracefully on worker shutdown. + + Args: + logger_factory: Optionally, override the logger_factory. + extra_processors: Optionally, add any processors at the end of the chain. + queue: Optionally, bring your own log queue. + producer: Optionally, bring your own Kafka producer. + cache_logger_on_first_use: Set whether to cache logger for performance. + Should always be True except in tests. 
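+
+    Returns:
+        A tuple with the tasks spawned to listen on the log queue and to handle
+        worker shutdown.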
+ """ + log_queue = queue if queue is not None else asyncio.Queue(maxsize=-1) + put_in_queue = PutInBatchExportsLogQueueProcessor(log_queue) + + base_processors: list[structlog.types.Processor] = [ + structlog.processors.add_log_level, + structlog.processors.format_exc_info, + structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S.%f", utc=True), + structlog.stdlib.PositionalArgumentsFormatter(), + add_batch_export_context, + put_in_queue, + EventRenamer("msg"), + structlog.processors.JSONRenderer(), + ] + extra_processors_to_add = extra_processors if extra_processors is not None else [] + + structlog.configure( + processors=base_processors + extra_processors_to_add, + logger_factory=logger_factory(), + cache_logger_on_first_use=cache_logger_on_first_use, + ) + listen_task = asyncio.create_task( + KafkaLogProducerFromQueue(queue=log_queue, topic=KAFKA_LOG_ENTRIES, producer=producer).listen() + ) + + async def worker_shutdown_handler(): + """Gracefully handle a Temporal Worker shutting down. + + Graceful handling means: + * Waiting until the queue is fully processed to avoid missing log messages. + * Cancel task listening on queue. + """ + await temporalio.activity.wait_for_worker_shutdown() + + listen_task.cancel() + + await asyncio.wait([listen_task]) + + worker_shutdown_handler_task = asyncio.create_task(worker_shutdown_handler()) + + return (listen_task, worker_shutdown_handler_task) + + +class PutInBatchExportsLogQueueProcessor: + """A StructLog processor that puts event_dict into a queue. + + We format event_dict as a message to be sent to Kafka by a queue listener. + """ + + def __init__(self, queue: asyncio.Queue): + self.queue = queue + + def __call__( + self, logger: logging.Logger, method_name: str, event_dict: structlog.types.EventDict + ) -> structlog.types.EventDict: + """Put a message into the queue, if we have all the necessary details. + + Always return event_dict so that processors that come later in the chain can do + their own thing. + """ + try: + message_dict = { + "instance_id": event_dict["workflow_run_id"], + "level": event_dict["level"], + "log_source": event_dict["log_source"], + "log_source_id": event_dict["log_source_id"], + "message": event_dict["event"], + "team_id": event_dict["team_id"], + "timestamp": event_dict["timestamp"], + } + except KeyError: + # We don't have the required keys to ingest this log. + # This could be because we are running outside an Activity/Workflow context. + return event_dict + + self.queue.put_nowait(json.dumps(message_dict).encode("utf-8")) + + return event_dict + + +def add_batch_export_context(logger: logging.Logger, method_name: str, event_dict: structlog.types.EventDict): + """A StructLog processor to populate event dict with batch export context variables. + + More specifically, the batch export context variables are coming from Temporal: + * workflow_run_id: The ID of the Temporal Workflow Execution running the batch export. + * workflow_id: The ID of the Temporal Workflow running the batch export. + * attempt: The current attempt number of the Temporal Workflow. + * log_source_id: The batch export ID. + * log_source: Either "batch_exports" or "batch_exports_backfill". + + We attempt to fetch the context from the activity information, and then from the workflow + information. If both are undefined, nothing is populated. When running this processor in + an activity or a workflow, at least one will be defined. 
+ """ + activity_info = attempt_to_fetch_activity_info() + workflow_info = attempt_to_fetch_workflow_info() + + info = activity_info or workflow_info + + if info is None: + return event_dict + + workflow_id, workflow_type, workflow_run_id, attempt = info + + if workflow_type == "backfill-batch-export": + # This works because the WorkflowID is made up like f"{batch_export_id}-Backfill-{data_interval_end}" + log_source_id = workflow_id.split("-Backfill")[0] + log_source = "batch_exports_backfill" + else: + # This works because the WorkflowID is made up like f"{batch_export_id}-{data_interval_end}" + # Since 'data_interval_end' is an iso formatted datetime string, it has two '-' to separate the + # date. Plus one more leaves us at the end of right at the end of 'batch_export_id'. + log_source_id = workflow_id.rsplit("-", maxsplit=3)[0] + log_source = "batch_exports" + + event_dict["workflow_id"] = workflow_id + event_dict["workflow_type"] = workflow_type + event_dict["log_source_id"] = log_source_id + event_dict["log_source"] = log_source + event_dict["workflow_run_id"] = workflow_run_id + event_dict["attempt"] = attempt + + return event_dict + + +Info = tuple[str, str, str, int] + + +def attempt_to_fetch_activity_info() -> Info | None: + """Fetch Activity information from Temporal. + + Returns: + None if calling outside an Activity, else the relevant Info. + """ + try: + activity_info = temporalio.activity.info() + except RuntimeError: + return None + else: + workflow_id = activity_info.workflow_id + workflow_type = activity_info.workflow_type + workflow_run_id = activity_info.workflow_run_id + attempt = activity_info.attempt + + return (workflow_id, workflow_type, workflow_run_id, attempt) + + +def attempt_to_fetch_workflow_info() -> Info | None: + """Fetch Workflow information from Temporal. + + Returns: + None if calling outside a Workflow, else the relevant Info. + """ + try: + workflow_info = temporalio.workflow.info() + except RuntimeError: + return None + else: + workflow_id = workflow_info.workflow_id + workflow_type = workflow_info.workflow_type + workflow_run_id = workflow_info.run_id + attempt = workflow_info.attempt + + return (workflow_id, workflow_type, workflow_run_id, attempt) + + +class KafkaLogProducerFromQueue: + """Produce log messages to Kafka by getting them from a queue. + + This KafkaLogProducerFromQueue was designed to ingest logs into the ClickHouse log_entries table. + For this reason, the messages we produce to Kafka are serialized as JSON in the schema expected by + the log_entries table. Eventually, we could de-couple this producer from the table schema, but + schema changes are rare in ClickHouse, and for now we are only using this for logs, so the tight + coupling is preferred over the extra complexity of de-coupling this producer. + + Attributes: + queue: The queue we are listening to get log event_dicts to serialize and produce. + topic: The topic to produce to. This should be left to the default KAFKA_LOG_ENTRIES. + key: The key for Kafka partitioning. Default to None for random partition. + producer: Optionally, bring your own aiokafka.AIOKafkaProducer. This is mostly here for testing. 
+ """ + + def __init__( + self, + queue: asyncio.Queue, + topic: str = KAFKA_LOG_ENTRIES, + key: str | None = None, + producer: aiokafka.AIOKafkaProducer | None = None, + ): + self.queue = queue + self.topic = topic + self.key = key + self.producer = ( + producer + if producer is not None + else aiokafka.AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_HOSTS, + security_protocol=settings.KAFKA_SECURITY_PROTOCOL or "PLAINTEXT", + acks="all", + api_version="2.5.0", + ) + ) + + async def listen(self): + """Listen to messages in queue and produce them to Kafka as they come. + + This is designed to be ran as an asyncio.Task, as it will wait forever for the queue + to have messages. + """ + await self.producer.start() + try: + while True: + msg = await self.queue.get() + await self.produce(msg) + + finally: + await self.producer.flush() + await self.producer.stop() + + async def produce(self, msg): + fut = await self.producer.send(self.topic, msg, key=self.key) + fut.add_done_callback(self.mark_queue_done) + await fut + + def mark_queue_done(self, _=None): + self.queue.task_done() diff --git a/posthog/temporal/workflows/postgres_batch_export.py b/posthog/temporal/workflows/postgres_batch_export.py index 8b66cfb0abb2c..eb7655f8e6f5a 100644 --- a/posthog/temporal/workflows/postgres_batch_export.py +++ b/posthog/temporal/workflows/postgres_batch_export.py @@ -2,12 +2,12 @@ import contextlib import datetime as dt import json +import typing from dataclasses import dataclass -import psycopg2 -import psycopg2.extensions +import psycopg from django.conf import settings -from psycopg2 import sql +from psycopg import sql from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -19,21 +19,21 @@ UpdateBatchExportRunStatusInputs, create_export_run, execute_batch_export_insert_activity, - get_batch_exports_logger, get_data_interval, get_results_iterator, get_rows_count, ) from posthog.temporal.workflows.clickhouse import get_client +from posthog.temporal.workflows.logger import bind_batch_exports_logger -@contextlib.contextmanager -def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions.connection]: +@contextlib.asynccontextmanager +async def postgres_connection(inputs) -> typing.AsyncIterator[psycopg.AsyncConnection]: """Manage a Postgres connection.""" - connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=inputs.user, password=inputs.password, - database=inputs.database, + dbname=inputs.database, host=inputs.host, port=inputs.port, # The 'hasSelfSignedCert' parameter in the postgres-plugin was provided mainly @@ -46,17 +46,17 @@ def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions. try: yield connection except Exception: - connection.rollback() + await connection.rollback() raise else: - connection.commit() + await connection.commit() finally: - connection.close() + await connection.close() -def copy_tsv_to_postgres( +async def copy_tsv_to_postgres( tsv_file, - postgres_connection: psycopg2.extensions.connection, + postgres_connection: psycopg.AsyncConnection, schema: str, table_name: str, schema_columns: list[str], @@ -65,35 +65,37 @@ def copy_tsv_to_postgres( Arguments: tsv_file: A file-like object to interpret as TSV to copy its contents. - postgres_connection: A connection to Postgres as setup by psycopg2. + postgres_connection: A connection to Postgres as setup by psycopg. schema: An existing schema where to create the table. table_name: The name of the table to create. 
diff --git a/posthog/temporal/workflows/postgres_batch_export.py b/posthog/temporal/workflows/postgres_batch_export.py
index 8b66cfb0abb2c..eb7655f8e6f5a 100644
--- a/posthog/temporal/workflows/postgres_batch_export.py
+++ b/posthog/temporal/workflows/postgres_batch_export.py
@@ -2,12 +2,12 @@
 import contextlib
 import datetime as dt
 import json
+import typing
 from dataclasses import dataclass
 
-import psycopg2
-import psycopg2.extensions
+import psycopg
 from django.conf import settings
-from psycopg2 import sql
+from psycopg import sql
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
 
@@ -19,21 +19,21 @@
     UpdateBatchExportRunStatusInputs,
     create_export_run,
     execute_batch_export_insert_activity,
-    get_batch_exports_logger,
     get_data_interval,
     get_results_iterator,
     get_rows_count,
 )
 from posthog.temporal.workflows.clickhouse import get_client
+from posthog.temporal.workflows.logger import bind_batch_exports_logger
 
 
-@contextlib.contextmanager
-def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions.connection]:
+@contextlib.asynccontextmanager
+async def postgres_connection(inputs) -> typing.AsyncIterator[psycopg.AsyncConnection]:
     """Manage a Postgres connection."""
-    connection = psycopg2.connect(
+    connection = await psycopg.AsyncConnection.connect(
         user=inputs.user,
         password=inputs.password,
-        database=inputs.database,
+        dbname=inputs.database,
         host=inputs.host,
         port=inputs.port,
         # The 'hasSelfSignedCert' parameter in the postgres-plugin was provided mainly
@@ -46,17 +46,17 @@ def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions.
     try:
         yield connection
     except Exception:
-        connection.rollback()
+        await connection.rollback()
         raise
     else:
-        connection.commit()
+        await connection.commit()
     finally:
-        connection.close()
+        await connection.close()
 
 
-def copy_tsv_to_postgres(
+async def copy_tsv_to_postgres(
     tsv_file,
-    postgres_connection: psycopg2.extensions.connection,
+    postgres_connection: psycopg.AsyncConnection,
     schema: str,
     table_name: str,
     schema_columns: list[str],
@@ -65,35 +65,37 @@
 
     Arguments:
         tsv_file: A file-like object to interpret as TSV to copy its contents.
-        postgres_connection: A connection to Postgres as setup by psycopg2.
+        postgres_connection: A connection to Postgres as setup by psycopg.
         schema: An existing schema where to create the table.
         table_name: The name of the table to create.
         schema_columns: A list of column names.
     """
     tsv_file.seek(0)
 
-    with postgres_connection.cursor() as cursor:
+    async with postgres_connection.cursor() as cursor:
         if schema:
-            cursor.execute(sql.SQL("SET search_path TO {schema}").format(schema=sql.Identifier(schema)))
-        cursor.copy_from(
-            tsv_file,
-            table_name,
-            null="",
-            columns=schema_columns,
-        )
+            await cursor.execute(sql.SQL("SET search_path TO {schema}").format(schema=sql.Identifier(schema)))
+        async with cursor.copy(
+            sql.SQL("COPY {table_name} ({fields}) FROM STDIN WITH DELIMITER AS '\t'").format(
+                table_name=sql.Identifier(table_name),
+                fields=sql.SQL(",").join((sql.Identifier(column) for column in schema_columns)),
+            )
+        ) as copy:
+            while data := tsv_file.read():
+                await copy.write(data)
 
 
 Field = tuple[str, str]
 Fields = collections.abc.Iterable[Field]
 
 
-def create_table_in_postgres(
-    postgres_connection: psycopg2.extensions.connection, schema: str | None, table_name: str, fields: Fields
+async def create_table_in_postgres(
+    postgres_connection: psycopg.AsyncConnection, schema: str | None, table_name: str, fields: Fields
 ) -> None:
     """Create a table in a Postgres database if it doesn't exist already.
 
     Arguments:
-        postgres_connection: A connection to Postgres as setup by psycopg2.
+        postgres_connection: A connection to Postgres as setup by psycopg.
         schema: An existing schema where to create the table.
         table_name: The name of the table to create.
         fields: An iterable of (name, type) tuples representing the fields of the table.
@@ -103,8 +105,8 @@ def create_table_in_postgres(
     else:
         table_identifier = sql.Identifier(table_name)
 
-    with postgres_connection.cursor() as cursor:
-        cursor.execute(
+    async with postgres_connection.cursor() as cursor:
+        await cursor.execute(
             sql.SQL(
                 """
                 CREATE TABLE IF NOT EXISTS {table} (
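Not part of the diff: create_table_in_postgres (its body continues in the next hunk) turns an iterable of (name, type) tuples into a CREATE TABLE statement. A usage sketch, assuming a PostHog-style dev environment and a scratch Postgres with made-up credentials:

```python
import asyncio

import psycopg

from posthog.temporal.workflows.postgres_batch_export import create_table_in_postgres


async def main():
    # Made-up connection parameters; point these at any scratch database you control.
    connection = await psycopg.AsyncConnection.connect(
        user="posthog", password="posthog", host="localhost", port=5432, dbname="scratch"
    )
    async with connection:
        await create_table_in_postgres(
            connection,
            schema="public",  # the schema must already exist
            table_name="events",
            # Each (name, type) tuple becomes one column of the generated CREATE TABLE.
            fields=[("uuid", "VARCHAR(200)"), ("event", "VARCHAR(200)"), ("properties", "JSONB")],
        )


asyncio.run(main())
```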
+ sql.SQL("{field} {type}").format( + field=sql.Identifier(field), + type=sql.SQL(field_type), + ) for field, field_type in fields ), ) @@ -143,7 +151,7 @@ class PostgresInsertInputs: @activity.defn async def insert_into_postgres_activity(inputs: PostgresInsertInputs): """Activity streams data from ClickHouse to Postgres.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="PostgreSQL") logger.info( "Running Postgres export batch %s - %s", inputs.data_interval_start, @@ -181,8 +189,8 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, ) - with postgres_connection(inputs) as connection: - create_table_in_postgres( + async with postgres_connection(inputs) as connection: + await create_table_in_postgres( connection, schema=inputs.schema, table_name=inputs.table_name, @@ -217,11 +225,10 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): json_columns = ("properties", "elements", "set", "set_once") with BatchExportTemporaryFile() as pg_file: - with postgres_connection(inputs) as connection: + async with postgres_connection(inputs) as connection: for result in results_iterator: row = { - key: json.dumps(result[key]) if key in json_columns and result[key] is not None else result[key] - for key in schema_columns + key: json.dumps(result[key]) if key in json_columns else result[key] for key in schema_columns } pg_file.write_records_to_tsv([row], fieldnames=schema_columns) @@ -231,7 +238,7 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): pg_file.records_since_last_reset, pg_file.bytes_since_last_reset, ) - copy_tsv_to_postgres( + await copy_tsv_to_postgres( pg_file, connection, inputs.schema, @@ -246,7 +253,7 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): pg_file.records_since_last_reset, pg_file.bytes_since_last_reset, ) - copy_tsv_to_postgres( + await copy_tsv_to_postgres( pg_file, connection, inputs.schema, @@ -274,7 +281,7 @@ def parse_inputs(inputs: list[str]) -> PostgresBatchExportInputs: @workflow.run async def run(self, inputs: PostgresBatchExportInputs): """Workflow implementation to export data to Postgres.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="PostgreSQL") data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) logger.info( "Starting Postgres export batch %s - %s", diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py index 74c1fb52662cc..cbce2304a1727 100644 --- a/posthog/temporal/workflows/redshift_batch_export.py +++ b/posthog/temporal/workflows/redshift_batch_export.py @@ -1,13 +1,13 @@ import collections.abc +import contextlib import datetime as dt +import itertools import json import typing from dataclasses import dataclass -import psycopg2 -import psycopg2.extensions -import psycopg2.extras -from psycopg2 import sql +import psycopg +from psycopg import sql from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -18,12 +18,12 @@ UpdateBatchExportRunStatusInputs, create_export_run, execute_batch_export_insert_activity, - get_batch_exports_logger, get_data_interval, get_results_iterator, get_rows_count, ) from posthog.temporal.workflows.clickhouse import get_client +from 
diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py
index 74c1fb52662cc..cbce2304a1727 100644
--- a/posthog/temporal/workflows/redshift_batch_export.py
+++ b/posthog/temporal/workflows/redshift_batch_export.py
@@ -1,13 +1,13 @@
 import collections.abc
+import contextlib
 import datetime as dt
+import itertools
 import json
 import typing
 from dataclasses import dataclass
 
-import psycopg2
-import psycopg2.extensions
-import psycopg2.extras
-from psycopg2 import sql
+import psycopg
+from psycopg import sql
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
 
@@ -18,12 +18,12 @@
     UpdateBatchExportRunStatusInputs,
     create_export_run,
     execute_batch_export_insert_activity,
-    get_batch_exports_logger,
     get_data_interval,
     get_results_iterator,
     get_rows_count,
 )
 from posthog.temporal.workflows.clickhouse import get_client
+from posthog.temporal.workflows.logger import bind_batch_exports_logger
 from posthog.temporal.workflows.postgres_batch_export import (
     PostgresInsertInputs,
     create_table_in_postgres,
@@ -31,9 +31,9 @@
 )
 
 
-def insert_records_to_redshift(
+async def insert_records_to_redshift(
     records: collections.abc.Iterator[dict[str, typing.Any]],
-    redshift_connection: psycopg2.extensions.connection,
+    redshift_connection: psycopg.AsyncConnection,
     schema: str,
     table: str,
     batch_size: int = 100,
@@ -57,29 +57,53 @@ def insert_records_to_redshift(
     make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low
     can significantly affect performance due to Redshift's poor handling of INSERTs.
     """
-    batch = [next(records)]
+    first_record = next(records)
+    columns = first_record.keys()
 
-    columns = batch[0].keys()
+    pre_query = sql.SQL("INSERT INTO {table} ({fields}) VALUES").format(
+        table=sql.Identifier(schema, table),
+        fields=sql.SQL(", ").join(map(sql.Identifier, columns)),
+    )
+    template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns)))
 
-    with redshift_connection.cursor() as cursor:
-        query = sql.SQL("INSERT INTO {table} ({fields}) VALUES {placeholder}").format(
-            table=sql.Identifier(schema, table),
-            fields=sql.SQL(", ").join(map(sql.Identifier, columns)),
-            placeholder=sql.Placeholder(),
-        )
-        template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns)))
+    redshift_connection.cursor_factory = psycopg.AsyncClientCursor
 
-        for record in records:
-            batch.append(record)
+    async with async_client_cursor_from_connection(redshift_connection) as cursor:
+        batch = [pre_query.as_string(cursor).encode("utf-8")]
+
+        for record in itertools.chain([first_record], records):
+            batch.append(cursor.mogrify(template, record).encode("utf-8"))
             if len(batch) < batch_size:
+                batch.append(b",")
                 continue
 
-            psycopg2.extras.execute_values(cursor, query, batch, template)
-            batch = []
+            await cursor.execute(b"".join(batch))
+            batch = [pre_query.as_string(cursor).encode("utf-8")]
 
         if len(batch) > 0:
-            psycopg2.extras.execute_values(cursor, query, batch, template)
+            await cursor.execute(b"".join(batch[:-1]))
+
+
+@contextlib.asynccontextmanager
+async def async_client_cursor_from_connection(
+    psycopg_connection: psycopg.AsyncConnection,
+) -> typing.AsyncIterator[psycopg.AsyncClientCursor]:
+    """Yield an AsyncClientCursor from a psycopg.AsyncConnection.
+
+    Keeps track of the current cursor_factory to restore it after we are done.
+    """
+    current_factory = psycopg_connection.cursor_factory
+    psycopg_connection.cursor_factory = psycopg.AsyncClientCursor
+
+    try:
+        async with psycopg_connection.cursor() as cursor:
+            # Not a fan of typing.cast, but we know this is a psycopg.AsyncClientCursor
+            # as we have just set cursor_factory.
+            cursor = typing.cast(psycopg.AsyncClientCursor, cursor)
+            yield cursor
+    finally:
+        psycopg_connection.cursor_factory = current_factory
 
 
 @dataclass
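Not part of the diff: to see what the batching in insert_records_to_redshift assembles, here is a database-free illustration; plain strings stand in for pre_query.as_string(cursor) and cursor.mogrify(template, record), and the records are made up. Note that len(batch) counts the query prefix and the "," separators as well as records, so batch_size bounds list elements rather than rows.

```python
import itertools

records = iter([{"event": "login"}, {"event": "logout"}, {"event": "purchase"}])
batch_size = 4

pre_query = 'INSERT INTO "schema"."table" ("event") VALUES'
statements: list[str] = []

first_record = next(records)
batch = [pre_query]

for record in itertools.chain([first_record], records):
    batch.append("('{}')".format(record["event"]))  # mogrify would also escape the values
    if len(batch) < batch_size:
        batch.append(",")
        continue

    statements.append("".join(batch))  # stands in for await cursor.execute(...)
    batch = [pre_query]

if len(batch) > 0:
    # Drop the trailing "," before flushing whatever is left over.
    statements.append("".join(batch[:-1]))

for statement in statements:
    print(statement)
# INSERT INTO "schema"."table" ("event") VALUES('login'),('logout')
# INSERT INTO "schema"."table" ("event") VALUES('purchase')
```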
""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="Redshift") logger.info( - "Running Postgres export batch %s - %s", + "Exporting batch %s - %s", inputs.data_interval_start, inputs.data_interval_end, ) @@ -138,7 +162,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs): ) return - logger.info("BatchExporting %s rows to Postgres", count) + logger.info("BatchExporting %s rows", count) results_iterator = get_results_iterator( client=client, @@ -150,8 +174,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs): ) properties_type = "VARCHAR(65535)" if inputs.properties_data_type == "varchar" else "SUPER" - with postgres_connection(inputs) as connection: - create_table_in_postgres( + async with postgres_connection(inputs) as connection: + await create_table_in_postgres( connection, schema=inputs.schema, table_name=inputs.table_name, @@ -192,8 +216,8 @@ def map_to_record(row: dict) -> dict: for key in schema_columns } - with postgres_connection(inputs) as connection: - insert_records_to_redshift( + async with postgres_connection(inputs) as connection: + await insert_records_to_redshift( (map_to_record(result) for result in results_iterator), connection, inputs.schema, inputs.table_name ) @@ -217,7 +241,7 @@ def parse_inputs(inputs: list[str]) -> RedshiftBatchExportInputs: @workflow.run async def run(self, inputs: RedshiftBatchExportInputs): """Workflow implementation to export data to Redshift.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="Redshift") data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) logger.info("Starting Redshift export batch %s - %s", data_interval_start, data_interval_end) diff --git a/posthog/temporal/workflows/s3_batch_export.py b/posthog/temporal/workflows/s3_batch_export.py index 6a81aeeb93a77..8a5a851d28b1c 100644 --- a/posthog/temporal/workflows/s3_batch_export.py +++ b/posthog/temporal/workflows/s3_batch_export.py @@ -20,12 +20,12 @@ UpdateBatchExportRunStatusInputs, create_export_run, execute_batch_export_insert_activity, - get_batch_exports_logger, get_data_interval, get_results_iterator, get_rows_count, ) from posthog.temporal.workflows.clickhouse import get_client +from posthog.temporal.workflows.logger import bind_batch_exports_logger def get_allowed_template_variables(inputs) -> dict[str, str]: @@ -303,7 +303,7 @@ class S3InsertInputs: async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]: """Initialize a S3MultiPartUpload and resume it from a hearbeat state if available.""" - logger = get_batch_exports_logger(inputs=inputs) + logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="S3") key = get_s3_key(inputs) s3_upload = S3MultiPartUpload( @@ -323,19 +323,22 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl except IndexError: # This is the error we expect when no details as the sequence will be empty. interval_start = inputs.data_interval_start - logger.info( - f"Did not receive details from previous activity Excecution. Export will start from the beginning: {interval_start}" + logger.debug( + "Did not receive details from previous activity Excecution. 
diff --git a/posthog/temporal/workflows/s3_batch_export.py b/posthog/temporal/workflows/s3_batch_export.py
index 6a81aeeb93a77..8a5a851d28b1c 100644
--- a/posthog/temporal/workflows/s3_batch_export.py
+++ b/posthog/temporal/workflows/s3_batch_export.py
@@ -20,12 +20,12 @@
     UpdateBatchExportRunStatusInputs,
     create_export_run,
     execute_batch_export_insert_activity,
-    get_batch_exports_logger,
     get_data_interval,
     get_results_iterator,
     get_rows_count,
 )
 from posthog.temporal.workflows.clickhouse import get_client
+from posthog.temporal.workflows.logger import bind_batch_exports_logger
 
 
 def get_allowed_template_variables(inputs) -> dict[str, str]:
@@ -303,7 +303,7 @@ class S3InsertInputs:
 
 async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]:
     """Initialize a S3MultiPartUpload and resume it from a hearbeat state if available."""
-    logger = get_batch_exports_logger(inputs=inputs)
+    logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="S3")
     key = get_s3_key(inputs)
 
     s3_upload = S3MultiPartUpload(
@@ -323,19 +323,22 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl
     except IndexError:
         # This is the error we expect when no details as the sequence will be empty.
         interval_start = inputs.data_interval_start
-        logger.info(
-            f"Did not receive details from previous activity Excecution. Export will start from the beginning: {interval_start}"
+        logger.debug(
+            "Did not receive details from previous activity Execution. Export will start from the beginning %s",
+            interval_start,
         )
     except Exception:
         # We still start from the beginning, but we make a point to log unexpected errors.
         # Ideally, any new exceptions should be added to the previous block after the first time and we will never land here.
         interval_start = inputs.data_interval_start
         logger.warning(
-            f"Did not receive details from previous activity Excecution due to an unexpected error. Export will start from the beginning: {interval_start}",
+            "Did not receive details from previous activity Execution due to an unexpected error. Export will start from the beginning %s",
+            interval_start,
        )
     else:
         logger.info(
-            f"Received details from previous activity. Export will attempt to resume from: {interval_start}",
+            "Received details from previous activity. Export will attempt to resume from %s",
+            interval_start,
         )
         s3_upload.continue_from_state(upload_state)
 
@@ -344,7 +347,8 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl
         interval_start = inputs.data_interval_start
 
         logger.info(
-            f"Export will start from the beginning as we are using brotli compression: {interval_start}",
+            "Export will start from the beginning as we are using brotli compression: %s",
+            interval_start,
         )
         await s3_upload.abort()
 
@@ -362,9 +366,9 @@ async def insert_into_s3_activity(inputs: S3InsertInputs):
 
     runs, timing out after say 30 seconds or something and upload multiple files.
     """
-    logger = get_batch_exports_logger(inputs=inputs)
+    logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="S3")
     logger.info(
-        "Running S3 export batch %s - %s",
+        "Exporting batch %s - %s",
         inputs.data_interval_start,
         inputs.data_interval_end,
     )
@@ -443,7 +447,7 @@ async def worker_shutdown_handler():
 
                 if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES:
                     logger.info(
-                        "Uploading part %s containing %s records with size %s bytes to S3",
+                        "Uploading part %s containing %s records with size %s bytes",
                         s3_upload.part_number + 1,
                         local_results_file.records_since_last_reset,
                         local_results_file.bytes_since_last_reset,
@@ -458,7 +462,7 @@
 
             if local_results_file.tell() > 0 and result is not None:
                 logger.info(
-                    "Uploading last part %s containing %s records with size %s bytes to S3",
+                    "Uploading last part %s containing %s records with size %s bytes",
                     s3_upload.part_number + 1,
                     local_results_file.records_since_last_reset,
                     local_results_file.bytes_since_last_reset,
@@ -490,9 +494,9 @@ def parse_inputs(inputs: list[str]) -> S3BatchExportInputs:
     @workflow.run
     async def run(self, inputs: S3BatchExportInputs):
         """Workflow implementation to export data to S3 bucket."""
-        logger = get_batch_exports_logger(inputs=inputs)
+        logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="S3")
         data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end)
-        logger.info("Starting S3 export batch %s - %s", data_interval_start, data_interval_end)
+        logger.info("Starting batch export %s - %s", data_interval_start, data_interval_end)
 
         create_export_run_inputs = CreateBatchExportRunInputs(
             team_id=inputs.team_id,
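Not part of the diff: the S3 changes above (and the matching edits in the other destinations) replace f-strings with %s-style arguments, so the message template stays constant and interpolation is left to structlog (the pinned structlog==23.2.0 supports positional arguments on its bound loggers). A minimal sketch:

```python
import structlog

logger = structlog.get_logger().bind(team_id=1, destination="S3")
interval_start = "2023-09-21 00:00:00"

# Before: the message is formatted before the logger ever sees it, so every call
# produces a unique template.
logger.info(f"Received details from previous activity. Export will attempt to resume from: {interval_start}")

# After: a constant template plus arguments; structlog handles the interpolation.
logger.info("Received details from previous activity. Export will attempt to resume from %s", interval_start)
```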
diff --git a/posthog/temporal/workflows/snowflake_batch_export.py b/posthog/temporal/workflows/snowflake_batch_export.py
index ec556e527192a..026e9a512c016 100644
--- a/posthog/temporal/workflows/snowflake_batch_export.py
+++ b/posthog/temporal/workflows/snowflake_batch_export.py
@@ -16,12 +16,12 @@
     UpdateBatchExportRunStatusInputs,
     create_export_run,
     execute_batch_export_insert_activity,
-    get_batch_exports_logger,
     get_data_interval,
     get_results_iterator,
     get_rows_count,
 )
 from posthog.temporal.workflows.clickhouse import get_client
+from posthog.temporal.workflows.logger import bind_batch_exports_logger
 
 
 class SnowflakeFileNotUploadedError(Exception):
@@ -98,9 +98,9 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):
 
     TODO: We're using JSON here, it's not the most efficient way to do this.
     """
-    logger = get_batch_exports_logger(inputs=inputs)
+    logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="Snowflake")
     logger.info(
-        "Running Snowflake export batch %s - %s",
+        "Exporting batch %s - %s",
        inputs.data_interval_start,
        inputs.data_interval_end,
     )
@@ -126,7 +126,7 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs):
         )
         return
 
-    logger.info("BatchExporting %s rows to Snowflake", count)
+    logger.info("BatchExporting %s rows", count)
 
     conn = snowflake.connector.connect(
         user=inputs.user,
@@ -294,10 +294,10 @@ def parse_inputs(inputs: list[str]) -> SnowflakeBatchExportInputs:
     @workflow.run
     async def run(self, inputs: SnowflakeBatchExportInputs):
         """Workflow implementation to export data to Snowflake table."""
-        logger = get_batch_exports_logger(inputs=inputs)
+        logger = await bind_batch_exports_logger(team_id=inputs.team_id, destination="Snowflake")
         data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end)
         logger.info(
-            "Starting Snowflake export batch %s - %s",
+            "Starting batch export %s - %s",
             data_interval_start,
             data_interval_end,
         )
diff --git a/requirements.in b/requirements.in
index d5c9e9449a32b..cad3efdc6e393 100644
--- a/requirements.in
+++ b/requirements.in
@@ -6,6 +6,7 @@
 #
 aiohttp>=3.8.4
 aioboto3==11.1
+aiokafka>=0.8
 antlr4-python3-runtime==4.13.1
 amqp==5.1.1
 boto3==1.26.76
@@ -58,6 +59,7 @@ Pillow==9.2.0
 posthoganalytics==3.0.1
 prance==0.22.2.22.0
 psycopg2-binary==2.9.7
+psycopg==3.1
 pyarrow==12.0.1
 pydantic==2.3.0
 pyjwt==2.4.0
@@ -76,6 +78,7 @@ snowflake-connector-python==3.0.4
 social-auth-app-django==5.0.0
 social-auth-core==4.3.0
 statshog==1.0.6
+structlog==23.2.0
 sqlparse==0.4.4
 temporalio==1.1.0
 token-bucket==0.3.0
diff --git a/requirements.txt b/requirements.txt
index 44eef0d14f9b6..86a4c6d7edab5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,8 @@ aiohttp==3.8.5
     #   openai
 aioitertools==0.11.0
     # via aiobotocore
+aiokafka==0.8.1
+    # via -r requirements.in
 aiosignal==1.2.0
     # via aiohttp
 amqp==5.1.1
@@ -39,6 +41,7 @@ async-generator==1.10
 async-timeout==4.0.2
     # via
     #   aiohttp
+    #   aiokafka
     #   redis
 attrs==21.4.0
     # via
@@ -280,7 +283,9 @@ jsonschema==4.4.0
 kafka-helper==0.2
     # via -r requirements.in
 kafka-python==2.0.2
-    # via -r requirements.in
+    # via
+    #   -r requirements.in
+    #   aiokafka
 kombu==5.3.2
     # via
     #   -r requirements.in
@@ -324,6 +329,7 @@ outcome==1.1.0
     # via trio
 packaging==23.1
     # via
+    #   aiokafka
     #   google-cloud-bigquery
     #   prance
     #   snowflake-connector-python
@@ -356,6 +362,8 @@ protobuf==4.22.1
     #   grpcio-status
     #   proto-plus
     #   temporalio
+psycopg==3.1
+    # via -r requirements.in
 psycopg2-binary==2.9.7
     # via -r requirements.in
 ptyprocess==0.6.0
@@ -493,8 +501,10 @@ sqlparse==0.4.4
     #   django
 statshog==1.0.6
     # via -r requirements.in
-structlog==21.2.0
-    # via django-structlog
+structlog==23.2.0
+    # via
+    #   -r requirements.in
+    #   django-structlog
 temporalio==1.1.0
     # via -r requirements.in
 tenacity==6.1.0
@@ -521,6 +531,7 @@ types-s3transfer==0.6.1
     # via boto3-stubs
 typing-extensions==4.7.1
     # via
+    #   psycopg
     #   pydantic
     #   pydantic-core
     #   qrcode