From 172ef815a00cda001c1233aa7b63c2b02b5c68ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 9 Nov 2023 02:21:26 +0100 Subject: [PATCH 1/5] refactor: Postgres batch exports now async --- .../test_postgres_batch_export_workflow.py | 88 +++++++++++-------- .../workflows/postgres_batch_export.py | 78 +++++++++------- .../workflows/redshift_batch_export.py | 35 ++++---- requirements.in | 1 + requirements.txt | 3 + 5 files changed, 115 insertions(+), 90 deletions(-) diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index 1cdfb595b1e47..19de3979bff4b 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -4,12 +4,12 @@ from random import randint from uuid import uuid4 -import psycopg2 +import psycopg +import psycopg.sql import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings -from psycopg2 import sql from temporalio import activity from temporalio.client import WorkflowFailureError from temporalio.common import RetryPolicy @@ -35,15 +35,19 @@ ] -def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None): +async def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None): """Assert provided events written to a given Postgres table.""" inserted_events = [] - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name))) + async with connection.cursor() as cursor: + await cursor.execute( + psycopg.sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format( + psycopg.sql.Identifier(schema, table_name) + ) + ) columns = [column.name for column in cursor.description] - for row in cursor.fetchall(): + for row in await cursor.fetchall(): event = dict(zip(columns, row)) event["timestamp"] = dt.datetime.fromisoformat(event["timestamp"].isoformat()) inserted_events.append(event) @@ -61,12 +65,12 @@ def assert_events_in_postgres(connection, schema, table_name, events, exclude_ev "distinct_id": event.get("distinct_id"), "elements": json.dumps(elements_chain), "event": event.get("event"), - "ip": properties.get("$ip", None) if properties else None, + "ip": properties.get("$ip", "") if properties else "", "properties": event.get("properties"), "set": properties.get("$set", None) if properties else None, "set_once": properties.get("$set_once", None) if properties else None, # Kept for backwards compatibility, but not exported anymore. - "site_url": None, + "site_url": "", # For compatibility with CH which doesn't parse timezone component, so we add it here assuming UTC. 
"timestamp": dt.datetime.fromisoformat(event.get("timestamp") + "+00:00"), "team_id": event.get("team_id"), @@ -94,75 +98,83 @@ def postgres_config(): } -@pytest.fixture -def setup_test_db(postgres_config): - connection = psycopg2.connect( +@pytest_asyncio.fixture +async def setup_test_db(postgres_config): + connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], host=postgres_config["host"], port=postgres_config["port"], ) - connection.set_session(autocommit=True) + await connection.set_autocommit(True) - with connection.cursor() as cursor: - cursor.execute( - sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), + async with connection.cursor() as cursor: + await cursor.execute( + psycopg.sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), (postgres_config["database"],), ) - if cursor.fetchone() is None: - cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + if await cursor.fetchone() is None: + await cursor.execute( + psycopg.sql.SQL("CREATE DATABASE {}").format(psycopg.sql.Identifier(postgres_config["database"])) + ) - connection.close() + await connection.close() # We need a new connection to connect to the database we just created. - connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], host=postgres_config["host"], port=postgres_config["port"], - database=postgres_config["database"], + dbname=postgres_config["database"], ) - connection.set_session(autocommit=True) + await connection.set_autocommit(True) - with connection.cursor() as cursor: - cursor.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"]))) + async with connection.cursor() as cursor: + await cursor.execute( + psycopg.sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(psycopg.sql.Identifier(postgres_config["schema"])) + ) yield - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) + async with connection.cursor() as cursor: + await cursor.execute( + psycopg.sql.SQL("DROP SCHEMA {} CASCADE").format(psycopg.sql.Identifier(postgres_config["schema"])) + ) - connection.close() + await connection.close() # We need a new connection to drop the database, as we cannot drop the current database. 
- connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], host=postgres_config["host"], port=postgres_config["port"], ) - connection.set_session(autocommit=True) + await connection.set_autocommit(True) - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + async with connection.cursor() as cursor: + await cursor.execute( + psycopg.sql.SQL("DROP DATABASE {}").format(psycopg.sql.Identifier(postgres_config["database"])) + ) - connection.close() + await connection.close() -@pytest.fixture -def postgres_connection(postgres_config, setup_test_db): - connection = psycopg2.connect( +@pytest_asyncio.fixture +async def postgres_connection(postgres_config, setup_test_db): + connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], - database=postgres_config["database"], + dbname=postgres_config["database"], host=postgres_config["host"], port=postgres_config["port"], ) yield connection - connection.close() + await connection.close() @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @@ -241,7 +253,7 @@ async def test_insert_into_postgres_activity_inserts_data_into_postgres_table( with override_settings(BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2): await activity_environment.run(insert_into_postgres_activity, insert_inputs) - assert_events_in_postgres( + await assert_events_in_postgres( connection=postgres_connection, schema=postgres_config["schema"], table_name="test_table", @@ -362,7 +374,7 @@ async def test_postgres_export_workflow( run = runs[0] assert run.status == "Completed" - assert_events_in_postgres( + await assert_events_in_postgres( postgres_connection, postgres_config["schema"], table_name, diff --git a/posthog/temporal/workflows/postgres_batch_export.py b/posthog/temporal/workflows/postgres_batch_export.py index e12add206a3c0..4918748e06a21 100644 --- a/posthog/temporal/workflows/postgres_batch_export.py +++ b/posthog/temporal/workflows/postgres_batch_export.py @@ -1,14 +1,13 @@ -import asyncio import collections.abc import contextlib import datetime as dt import json +import typing from dataclasses import dataclass -import psycopg2 -import psycopg2.extensions +import psycopg from django.conf import settings -from psycopg2 import sql +from psycopg import sql from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -29,13 +28,13 @@ from posthog.temporal.workflows.metrics import get_bytes_exported_metric, get_rows_exported_metric -@contextlib.contextmanager -def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions.connection]: +@contextlib.asynccontextmanager +async def postgres_connection(inputs) -> typing.AsyncIterator[psycopg.AsyncConnection]: """Manage a Postgres connection.""" - connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=inputs.user, password=inputs.password, - database=inputs.database, + dbname=inputs.database, host=inputs.host, port=inputs.port, # The 'hasSelfSignedCert' parameter in the postgres-plugin was provided mainly @@ -48,17 +47,17 @@ def postgres_connection(inputs) -> collections.abc.Iterator[psycopg2.extensions. 
try: yield connection except Exception: - connection.rollback() + await connection.rollback() raise else: - connection.commit() + await connection.commit() finally: - connection.close() + await connection.close() async def copy_tsv_to_postgres( tsv_file, - postgres_connection: psycopg2.extensions.connection, + postgres_connection: psycopg.AsyncConnection, schema: str, table_name: str, schema_columns: list[str], @@ -67,36 +66,37 @@ async def copy_tsv_to_postgres( Arguments: tsv_file: A file-like object to interpret as TSV to copy its contents. - postgres_connection: A connection to Postgres as setup by psycopg2. + postgres_connection: A connection to Postgres as setup by psycopg. schema: An existing schema where to create the table. table_name: The name of the table to create. schema_columns: A list of column names. """ tsv_file.seek(0) - with postgres_connection.cursor() as cursor: + async with postgres_connection.cursor() as cursor: if schema: - cursor.execute(sql.SQL("SET search_path TO {schema}").format(schema=sql.Identifier(schema))) - await asyncio.to_thread( - cursor.copy_from, - tsv_file, - table_name, - null="", - columns=schema_columns, - ) + await cursor.execute(sql.SQL("SET search_path TO {schema}").format(schema=sql.Identifier(schema))) + async with cursor.copy( + sql.SQL("COPY {table_name} ({fields}) FROM STDIN WITH DELIMITER AS '\t'").format( + table_name=sql.Identifier(table_name), + fields=sql.SQL(",").join((sql.Identifier(column) for column in schema_columns)), + ) + ) as copy: + while data := tsv_file.read(): + await copy.write(data) Field = tuple[str, str] Fields = collections.abc.Iterable[Field] -def create_table_in_postgres( - postgres_connection: psycopg2.extensions.connection, schema: str | None, table_name: str, fields: Fields +async def create_table_in_postgres( + postgres_connection: psycopg.AsyncConnection, schema: str | None, table_name: str, fields: Fields ) -> None: """Create a table in a Postgres database if it doesn't exist already. Arguments: - postgres_connection: A connection to Postgres as setup by psycopg2. + postgres_connection: A connection to Postgres as setup by psycopg. schema: An existing schema where to create the table. table_name: The name of the table to create. fields: An iterable of (name, type) tuples representing the fields of the table. @@ -106,8 +106,8 @@ def create_table_in_postgres( else: table_identifier = sql.Identifier(table_name) - with postgres_connection.cursor() as cursor: - cursor.execute( + async with postgres_connection.cursor() as cursor: + await cursor.execute( sql.SQL( """ CREATE TABLE IF NOT EXISTS {table} ( @@ -117,7 +117,13 @@ def create_table_in_postgres( ).format( table=table_identifier, fields=sql.SQL(",").join( - sql.SQL("{field} {type}").format(field=sql.Identifier(field), type=sql.SQL(field_type)) + # typing.LiteralString is not available in Python 3.10. + # So, we ignore it for now. + # This is safe as we are hardcoding the type values anyways. 
+ sql.SQL("{field} {type}").format( + field=sql.Identifier(field), + type=sql.SQL(field_type), # type: ignore + ) for field, field_type in fields ), ) @@ -184,8 +190,8 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, ) - with postgres_connection(inputs) as connection: - create_table_in_postgres( + async with postgres_connection(inputs) as connection: + await create_table_in_postgres( connection, schema=inputs.schema, table_name=inputs.table_name, @@ -219,10 +225,16 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs): ] json_columns = ("properties", "elements", "set", "set_once") + rows_exported = get_rows_exported_metric() + bytes_exported = get_bytes_exported_metric() + with BatchExportTemporaryFile() as pg_file: - with postgres_connection(inputs) as connection: - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() + async with postgres_connection(inputs) as connection: + for result in results_iterator: + row = { + key: json.dumps(result[key]) if key in json_columns else result[key] for key in schema_columns + } + pg_file.write_records_to_tsv([row], fieldnames=schema_columns) async def flush_to_postgres(): logger.debug( diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py index da5b780111f2c..cf4034ca255e3 100644 --- a/posthog/temporal/workflows/redshift_batch_export.py +++ b/posthog/temporal/workflows/redshift_batch_export.py @@ -4,10 +4,7 @@ import typing from dataclasses import dataclass -import psycopg2 -import psycopg2.extensions -import psycopg2.extras -from psycopg2 import sql +import psycopg from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -32,9 +29,9 @@ ) -def insert_records_to_redshift( +async def insert_records_to_redshift( records: collections.abc.Iterator[dict[str, typing.Any]], - redshift_connection: psycopg2.extensions.connection, + redshift_connection: psycopg.AsyncConnection, schema: str, table: str, batch_size: int = 100, @@ -62,18 +59,18 @@ def insert_records_to_redshift( columns = batch[0].keys() - with redshift_connection.cursor() as cursor: - query = sql.SQL("INSERT INTO {table} ({fields}) VALUES {placeholder}").format( - table=sql.Identifier(schema, table), - fields=sql.SQL(", ").join(map(sql.Identifier, columns)), - placeholder=sql.Placeholder(), + async with redshift_connection.cursor() as cursor: + query = psycopg.sql.SQL("INSERT INTO {table} ({fields}) VALUES {placeholder}").format( + table=psycopg.sql.Identifier(schema, table), + fields=psycopg.sql.SQL(", ").join(map(psycopg.sql.Identifier, columns)), + placeholder=psycopg.sql.Placeholder(), ) - template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns))) + template = psycopg.sql.SQL("({})").format(psycopg.sql.SQL(", ").join(map(psycopg.sql.Placeholder, columns))) rows_exported = get_rows_exported_metric() - def flush_to_redshift(): - psycopg2.extras.execute_values(cursor, query, batch, template) + async def flush_to_redshift(): + await cursor.execute_many(cursor, query, batch, template) rows_exported.add(len(batch)) # It would be nice to record BYTES_EXPORTED for Redshift, but it's not worth estimating # the byte size of each batch the way things are currently written. 
We can revisit this @@ -85,11 +82,11 @@ def flush_to_redshift(): if len(batch) < batch_size: continue - flush_to_redshift() + await flush_to_redshift() batch = [] if len(batch) > 0: - flush_to_redshift() + await flush_to_redshift() @dataclass @@ -160,7 +157,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs): ) properties_type = "VARCHAR(65535)" if inputs.properties_data_type == "varchar" else "SUPER" - with postgres_connection(inputs) as connection: + async with postgres_connection(inputs) as connection: create_table_in_postgres( connection, schema=inputs.schema, @@ -202,8 +199,8 @@ def map_to_record(row: dict) -> dict: for key in schema_columns } - with postgres_connection(inputs) as connection: - insert_records_to_redshift( + async with postgres_connection(inputs) as connection: + await insert_records_to_redshift( (map_to_record(result) for result in results_iterator), connection, inputs.schema, inputs.table_name ) diff --git a/requirements.in b/requirements.in index 7455538c55210..574ff57900d39 100644 --- a/requirements.in +++ b/requirements.in @@ -59,6 +59,7 @@ Pillow==9.2.0 posthoganalytics==3.0.1 prance==0.22.2.22.0 psycopg2-binary==2.9.7 +psycopg==3.1 pyarrow==12.0.1 pydantic==2.3.0 pyjwt==2.4.0 diff --git a/requirements.txt b/requirements.txt index bda638053b17d..1efa70e5318f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -362,6 +362,8 @@ protobuf==4.22.1 # grpcio-status # proto-plus # temporalio +psycopg==3.1 + # via -r requirements.in psycopg2-binary==2.9.7 # via -r requirements.in ptyprocess==0.6.0 @@ -529,6 +531,7 @@ types-s3transfer==0.6.1 # via boto3-stubs typing-extensions==4.7.1 # via + # psycopg # pydantic # pydantic-core # qrcode From c488955ba2386fef8bca90808d5867099b15757c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 9 Nov 2023 02:30:17 +0100 Subject: [PATCH 2/5] fix: Let's use from psycopg import sql instead --- .../test_postgres_batch_export_workflow.py | 22 ++++------- .../workflows/postgres_batch_export.py | 2 +- .../workflows/redshift_batch_export.py | 38 +++++++++++-------- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index 19de3979bff4b..924982404d3bf 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -5,11 +5,11 @@ from uuid import uuid4 import psycopg -import psycopg.sql import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings +from psycopg import sql from temporalio import activity from temporalio.client import WorkflowFailureError from temporalio.common import RetryPolicy @@ -41,9 +41,7 @@ async def assert_events_in_postgres(connection, schema, table_name, events, excl async with connection.cursor() as cursor: await cursor.execute( - psycopg.sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format( - psycopg.sql.Identifier(schema, table_name) - ) + sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name)) ) columns = [column.name for column in cursor.description] @@ -110,14 +108,12 @@ async def setup_test_db(postgres_config): async with connection.cursor() as cursor: await cursor.execute( - psycopg.sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), + sql.SQL("SELECT 1 FROM pg_database WHERE 
datname = %s"), (postgres_config["database"],), ) if await cursor.fetchone() is None: - await cursor.execute( - psycopg.sql.SQL("CREATE DATABASE {}").format(psycopg.sql.Identifier(postgres_config["database"])) - ) + await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) await connection.close() @@ -133,15 +129,13 @@ async def setup_test_db(postgres_config): async with connection.cursor() as cursor: await cursor.execute( - psycopg.sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(psycopg.sql.Identifier(postgres_config["schema"])) + sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])) ) yield async with connection.cursor() as cursor: - await cursor.execute( - psycopg.sql.SQL("DROP SCHEMA {} CASCADE").format(psycopg.sql.Identifier(postgres_config["schema"])) - ) + await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) await connection.close() @@ -155,9 +149,7 @@ async def setup_test_db(postgres_config): await connection.set_autocommit(True) async with connection.cursor() as cursor: - await cursor.execute( - psycopg.sql.SQL("DROP DATABASE {}").format(psycopg.sql.Identifier(postgres_config["database"])) - ) + await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) await connection.close() diff --git a/posthog/temporal/workflows/postgres_batch_export.py b/posthog/temporal/workflows/postgres_batch_export.py index 4918748e06a21..025c29c6421ea 100644 --- a/posthog/temporal/workflows/postgres_batch_export.py +++ b/posthog/temporal/workflows/postgres_batch_export.py @@ -122,7 +122,7 @@ async def create_table_in_postgres( # This is safe as we are hardcoding the type values anyways. sql.SQL("{field} {type}").format( field=sql.Identifier(field), - type=sql.SQL(field_type), # type: ignore + type=sql.SQL(field_type), ) for field, field_type in fields ), diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py index cf4034ca255e3..c0b5814f99d95 100644 --- a/posthog/temporal/workflows/redshift_batch_export.py +++ b/posthog/temporal/workflows/redshift_batch_export.py @@ -1,10 +1,12 @@ import collections.abc import datetime as dt +import itertools import json import typing from dataclasses import dataclass import psycopg +from psycopg import sql from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -55,38 +57,42 @@ async def insert_records_to_redshift( make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low can significantly affect performance due to Redshift's poor handling of INSERTs. 
""" - batch = [next(records)] + first_record = next(records) + columns = first_record.keys() - columns = batch[0].keys() + pre_query = sql.SQL("INSERT INTO {table} ({fields}) VALUES").format( + table=sql.Identifier(schema, table), + fields=sql.SQL(", ").join(map(sql.Identifier, columns)), + ) + template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns))) + + redshift_connection.cursor_factory = psycopg.AsyncClientCursor async with redshift_connection.cursor() as cursor: - query = psycopg.sql.SQL("INSERT INTO {table} ({fields}) VALUES {placeholder}").format( - table=psycopg.sql.Identifier(schema, table), - fields=psycopg.sql.SQL(", ").join(map(psycopg.sql.Identifier, columns)), - placeholder=psycopg.sql.Placeholder(), - ) - template = psycopg.sql.SQL("({})").format(psycopg.sql.SQL(", ").join(map(psycopg.sql.Placeholder, columns))) + batch = [pre_query.as_string(cursor).encode("utf-8")] rows_exported = get_rows_exported_metric() - async def flush_to_redshift(): - await cursor.execute_many(cursor, query, batch, template) + async def flush_to_redshift(batch): + await cursor.execute(b"".join(batch)) rows_exported.add(len(batch)) # It would be nice to record BYTES_EXPORTED for Redshift, but it's not worth estimating # the byte size of each batch the way things are currently written. We can revisit this # in the future if we decide it's useful enough. - for record in records: - batch.append(record) + for record in itertools.chain([first_record], records): + batch.append(cursor.mogrify(template, record).encode("utf-8")) if len(batch) < batch_size: + batch.append(b",") continue - await flush_to_redshift() - batch = [] + if len(batch) > 0: + await flush_to_redshift(batch) + batch = [pre_query.as_string(cursor).encode("utf-8")] if len(batch) > 0: - await flush_to_redshift() + await flush_to_redshift(batch[:-1]) @dataclass @@ -158,7 +164,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs): properties_type = "VARCHAR(65535)" if inputs.properties_data_type == "varchar" else "SUPER" async with postgres_connection(inputs) as connection: - create_table_in_postgres( + await create_table_in_postgres( connection, schema=inputs.schema, table_name=inputs.table_name, From bdb34f2d18d30079e7e8c9f0d2239b79e5f130a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 9 Nov 2023 14:46:13 +0100 Subject: [PATCH 3/5] test: Update Redshift tests --- .../temporal/tests/batch_exports/conftest.py | 70 +++++++++ .../test_postgres_batch_export_workflow.py | 62 +------- .../test_redshift_batch_export_workflow.py | 134 ++++++------------ 3 files changed, 115 insertions(+), 151 deletions(-) diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py index 9fafd040b027d..98250309e33cc 100644 --- a/posthog/temporal/tests/batch_exports/conftest.py +++ b/posthog/temporal/tests/batch_exports/conftest.py @@ -1,5 +1,7 @@ +import psycopg import pytest import pytest_asyncio +from psycopg import sql @pytest.fixture @@ -40,3 +42,71 @@ async def truncate_events(clickhouse_client): """ yield await clickhouse_client.execute_query("TRUNCATE TABLE IF EXISTS `sharded_events`") + + +@pytest_asyncio.fixture +async def setup_postgres_test_db(postgres_config): + """Fixture to manage a database for Redshift export testing. + + Managing a test database involves the following steps: + 1. Creating a test database. + 2. Initializing a connection to that database. + 3. Creating a test schema. + 4. 
Yielding the connection to be used in tests. + 5. After tests, drop the test schema and any tables in it. + 6. Drop the test database. + """ + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), + (postgres_config["database"],), + ) + + if await cursor.fetchone() is None: + await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() + + # We need a new connection to connect to the database we just created. + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + dbname=postgres_config["database"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute( + sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])) + ) + + yield + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) + + await connection.close() + + # We need a new connection to drop the database, as we cannot drop the current database. + connection = await psycopg.AsyncConnection.connect( + user=postgres_config["user"], + password=postgres_config["password"], + host=postgres_config["host"], + port=postgres_config["port"], + ) + await connection.set_autocommit(True) + + async with connection.cursor() as cursor: + await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) + + await connection.close() diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index 924982404d3bf..6a70e9f2eb74c 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -97,65 +97,7 @@ def postgres_config(): @pytest_asyncio.fixture -async def setup_test_db(postgres_config): - connection = await psycopg.AsyncConnection.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - ) - await connection.set_autocommit(True) - - async with connection.cursor() as cursor: - await cursor.execute( - sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), - (postgres_config["database"],), - ) - - if await cursor.fetchone() is None: - await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))) - - await connection.close() - - # We need a new connection to connect to the database we just created. 
- connection = await psycopg.AsyncConnection.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - dbname=postgres_config["database"], - ) - await connection.set_autocommit(True) - - async with connection.cursor() as cursor: - await cursor.execute( - sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])) - ) - - yield - - async with connection.cursor() as cursor: - await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))) - - await connection.close() - - # We need a new connection to drop the database, as we cannot drop the current database. - connection = await psycopg.AsyncConnection.connect( - user=postgres_config["user"], - password=postgres_config["password"], - host=postgres_config["host"], - port=postgres_config["port"], - ) - await connection.set_autocommit(True) - - async with connection.cursor() as cursor: - await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) - - await connection.close() - - -@pytest_asyncio.fixture -async def postgres_connection(postgres_config, setup_test_db): +async def postgres_connection(postgres_config, setup_postgres_test_db): connection = await psycopg.AsyncConnection.connect( user=postgres_config["user"], password=postgres_config["password"], @@ -333,7 +275,7 @@ async def test_postgres_export_workflow( inputs = PostgresBatchExportInputs( team_id=ateam.pk, batch_export_id=str(postgres_batch_export.id), - data_interval_end="2023-04-25 14:30:00.000000", + data_interval_end=data_interval_end.isoformat(), interval=interval, **postgres_batch_export.destination.config, ) diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py index 2888484077371..176b487ff94a0 100644 --- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py @@ -1,15 +1,16 @@ import datetime as dt import json import os +import warnings from random import randint from uuid import uuid4 -import psycopg2 +import psycopg import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings -from psycopg2 import sql +from psycopg import sql from temporalio.common import RetryPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker @@ -33,24 +34,24 @@ "REDSHIFT_HOST", ) -SKIP_IF_MISSING_REQUIRED_ENV_VARS = pytest.mark.skipif( - any(env_var not in os.environ for env_var in REQUIRED_ENV_VARS), - reason="Redshift required env vars are not set", -) +MISSING_REQUIRED_ENV_VARS = any(env_var not in os.environ for env_var in REQUIRED_ENV_VARS) + -pytestmark = [SKIP_IF_MISSING_REQUIRED_ENV_VARS, pytest.mark.django_db, pytest.mark.asyncio] +pytestmark = [pytest.mark.django_db, pytest.mark.asyncio] -def assert_events_in_redshift(connection, schema, table_name, events, exclude_events: list[str] | None = None): +async def assert_events_in_redshift(connection, schema, table_name, events, exclude_events: list[str] | None = None): """Assert provided events written to a given Redshift table.""" inserted_events = [] - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT * FROM {} ORDER BY timestamp").format(sql.Identifier(schema, table_name))) + async with 
connection.cursor() as cursor: + await cursor.execute( + sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name)) + ) columns = [column.name for column in cursor.description] - for row in cursor.fetchall(): + for row in await cursor.fetchall(): event = dict(zip(columns, row)) event["timestamp"] = dt.datetime.fromisoformat(event["timestamp"].isoformat()) inserted_events.append(event) @@ -81,7 +82,7 @@ def assert_events_in_redshift(connection, schema, table_name, events, exclude_ev } expected_events.append(expected_event) - expected_events.sort(key=lambda x: x["timestamp"]) + expected_events.sort(key=lambda x: (x["event"], x["timestamp"])) assert len(inserted_events) == len(expected_events) # First check one event, the first one, so that we can get a nice diff if @@ -94,17 +95,26 @@ def assert_events_in_redshift(connection, schema, table_name, events, exclude_ev def redshift_config(): """Fixture to provide a default configuration for Redshift batch exports. - Reads required env vars to construct configuration. + Reads required env vars to construct configuration, but if not present + we default to local development PostgreSQL database, which should be mostly compatible. """ - user = os.environ["REDSHIFT_USER"] - password = os.environ["REDSHIFT_PASSWORD"] - host = os.environ["REDSHIFT_HOST"] - port = os.environ.get("REDSHIFT_PORT", "5439") + if MISSING_REQUIRED_ENV_VARS: + user = settings.PG_USER + password = settings.PG_PASSWORD + host = settings.PG_HOST + port = int(settings.PG_PORT) + warnings.warn("Missing required Redshift env vars. Running tests against local PG database.", stacklevel=1) + + else: + user = os.environ["REDSHIFT_USER"] + password = os.environ["REDSHIFT_PASSWORD"] + host = os.environ["REDSHIFT_HOST"] + port = os.environ.get("REDSHIFT_PORT", "5439") return { "user": user, "password": password, - "database": "exports_test_database", + "database": "dev", "schema": "exports_test_schema", "host": host, "port": int(port), @@ -112,89 +122,30 @@ def redshift_config(): @pytest.fixture -def setup_test_db(redshift_config): - """Fixture to manage a database for Redshift export testing. - - Managing a test database involves the following steps: - 1. Creating a test database. - 2. Initializing a connection to that database. - 3. Creating a test schema. - 4. Yielding the connection to be used in tests. - 5. After tests, drop the test schema and any tables in it. - 6. Drop the test database. - """ - connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database="dev", - ) - connection.set_session(autocommit=True) +def postgres_config(redshift_config): + """We shadow this name so that setup_postgres_test_db works with Redshift.""" + return redshift_config - with connection.cursor() as cursor: - cursor.execute(sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"), (redshift_config["database"],)) - - if cursor.fetchone() is None: - cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(redshift_config["database"]))) - - connection.close() - - # We need a new connection to connect to the database we just created. 
- connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database=redshift_config["database"], - ) - connection.set_session(autocommit=True) - with connection.cursor() as cursor: - cursor.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(redshift_config["schema"]))) - - yield - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(redshift_config["schema"]))) - - connection.close() - - # We need a new connection to drop the database, as we cannot drop the current database. - connection = psycopg2.connect( - user=redshift_config["user"], - password=redshift_config["password"], - host=redshift_config["host"], - port=redshift_config["port"], - database="dev", - ) - connection.set_session(autocommit=True) - - with connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(redshift_config["database"]))) - - connection.close() - - -@pytest.fixture -def psycopg2_connection(redshift_config, setup_test_db): +@pytest_asyncio.fixture +async def psycopg_connection(redshift_config, setup_postgres_test_db): """Fixture to manage a psycopg2 connection.""" - connection = psycopg2.connect( + connection = await psycopg.AsyncConnection.connect( user=redshift_config["user"], password=redshift_config["password"], - database=redshift_config["database"], + dbname=redshift_config["database"], host=redshift_config["host"], port=redshift_config["port"], ) yield connection - connection.close() + await connection.close() @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( - clickhouse_client, activity_environment, psycopg2_connection, redshift_config, exclude_events + clickhouse_client, activity_environment, psycopg_connection, redshift_config, exclude_events ): """Test that the insert_into_redshift_activity function inserts data into a Redshift table. @@ -265,8 +216,8 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( await activity_environment.run(insert_into_redshift_activity, insert_inputs) - assert_events_in_redshift( - connection=psycopg2_connection, + await assert_events_in_redshift( + connection=psycopg_connection, schema=redshift_config["schema"], table_name="test_table", events=events + events_with_no_properties, @@ -308,11 +259,12 @@ async def redshift_batch_export(ateam, table_name, redshift_config, interval, ex async def test_redshift_export_workflow( clickhouse_client, redshift_config, - psycopg2_connection, + psycopg_connection, interval, redshift_batch_export, ateam, exclude_events, + table_name, ): """Test Redshift Export Workflow end-to-end. 
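The postgres_config fixture above relies on ordinary pytest fixture resolution: because the shared setup_postgres_test_db fixture requests a fixture named postgres_config, redefining that name in the Redshift test module points the shared setup at the Redshift (or fallback local Postgres) configuration. A minimal sketch of the pattern with simplified config values and a stubbed-out provisioning body; this is an illustration, not code from the patch:

import pytest

# Shared conftest: the provisioning fixture only ever requests "postgres_config" by name.
@pytest.fixture
def setup_postgres_test_db(postgres_config):
    # Create the database and schema described by postgres_config, yield, then drop them.
    yield

# Redshift test module: redefining postgres_config here makes the shared fixture
# provision the Redshift test database without a Redshift-specific copy of the setup code.
@pytest.fixture
def redshift_config():
    return {"user": "posthog", "password": "posthog", "host": "localhost",
            "port": 5439, "database": "dev", "schema": "exports_test_schema"}

@pytest.fixture
def postgres_config(redshift_config):
    return redshift_config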
@@ -385,8 +337,8 @@ async def test_redshift_export_workflow( run = runs[0] assert run.status == "Completed" - assert_events_in_redshift( - psycopg2_connection, + await assert_events_in_redshift( + psycopg_connection, redshift_config["schema"], table_name, events=events, From 07d438b0cd1edd525a90d9d72ea3aa9d5ba6209d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 9 Nov 2023 15:03:43 +0100 Subject: [PATCH 4/5] fix: Typing issues --- .../workflows/redshift_batch_export.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py index c0b5814f99d95..eab68d5eeb2ae 100644 --- a/posthog/temporal/workflows/redshift_batch_export.py +++ b/posthog/temporal/workflows/redshift_batch_export.py @@ -1,4 +1,5 @@ import collections.abc +import contextlib import datetime as dt import itertools import json @@ -68,7 +69,7 @@ async def insert_records_to_redshift( redshift_connection.cursor_factory = psycopg.AsyncClientCursor - async with redshift_connection.cursor() as cursor: + async with async_client_cursor_from_connection(redshift_connection) as cursor: batch = [pre_query.as_string(cursor).encode("utf-8")] rows_exported = get_rows_exported_metric() @@ -95,6 +96,27 @@ async def flush_to_redshift(batch): await flush_to_redshift(batch[:-1]) +@contextlib.asynccontextmanager +async def async_client_cursor_from_connection( + psycopg_connection: psycopg.AsyncConnection, +) -> typing.AsyncIterator[psycopg.AsyncClientCursor]: + """Yield a AsyncClientCursor from a psycopg.AsyncConnection. + + Keeps track of the current cursor_factory to set it after we are done. + """ + current_factory = psycopg_connection.cursor_factory + psycopg_connection.cursor_factory = psycopg.AsyncClientCursor + + try: + async with psycopg_connection.cursor() as cursor: + # Not a fan of typing.cast, but we know this is an psycopg.AsyncClientCursor + # as we have just set cursor_factory. + cursor = typing.cast(psycopg.AsyncClientCursor, cursor) + yield cursor + finally: + psycopg_connection.cursor_factory = current_factory + + @dataclass class RedshiftInsertInputs(PostgresInsertInputs): """Inputs for Redshift insert activity. 
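A minimal usage sketch of async_client_cursor_from_connection from the hunk above; the demo function and the literal query are illustrative assumptions, not code from the patch:

import psycopg

from posthog.temporal.workflows.redshift_batch_export import async_client_cursor_from_connection

async def demo(connection: psycopg.AsyncConnection) -> None:
    async with async_client_cursor_from_connection(connection) as cursor:
        # Client-side cursors bind parameters in Python, so mogrify() is available
        # to render placeholders into literal SQL text for batching.
        rendered = cursor.mogrify("SELECT %s", ("hello",))
        await cursor.execute(rendered)
    # On exit the connection's previous cursor_factory is restored, so other code
    # paths on the same connection are unaffected.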
From ee69926487b0646f9d903cd793f2c9a3aca5d329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 15 Nov 2023 00:28:04 +0100 Subject: [PATCH 5/5] fix: Main insert batch loop --- posthog/temporal/workflows/redshift_batch_export.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/posthog/temporal/workflows/redshift_batch_export.py b/posthog/temporal/workflows/redshift_batch_export.py index eab68d5eeb2ae..06843289aee5e 100644 --- a/posthog/temporal/workflows/redshift_batch_export.py +++ b/posthog/temporal/workflows/redshift_batch_export.py @@ -66,17 +66,14 @@ async def insert_records_to_redshift( fields=sql.SQL(", ").join(map(sql.Identifier, columns)), ) template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns))) - - redshift_connection.cursor_factory = psycopg.AsyncClientCursor + rows_exported = get_rows_exported_metric() async with async_client_cursor_from_connection(redshift_connection) as cursor: batch = [pre_query.as_string(cursor).encode("utf-8")] - rows_exported = get_rows_exported_metric() - async def flush_to_redshift(batch): await cursor.execute(b"".join(batch)) - rows_exported.add(len(batch)) + rows_exported.add(len(batch) - 1) # It would be nice to record BYTES_EXPORTED for Redshift, but it's not worth estimating # the byte size of each batch the way things are currently written. We can revisit this # in the future if we decide it's useful enough. @@ -88,7 +85,6 @@ async def flush_to_redshift(batch): batch.append(b",") continue - if len(batch) > 0: await flush_to_redshift(batch) batch = [pre_query.as_string(cursor).encode("utf-8")]
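Taken together, the series converges on rendering every row client-side with mogrify() and sending one multi-row INSERT per batch. A self-contained sketch of that technique under assumed names: insert_batched, the rows list, and the test_table target are illustrative, while the real activity streams records from ClickHouse and uses the loop fixed in PATCH 5.

import psycopg
from psycopg import sql

async def insert_batched(connection: psycopg.AsyncConnection, rows: list[dict], batch_size: int = 100) -> None:
    if not rows:
        return

    columns = list(rows[0].keys())
    pre_query = sql.SQL("INSERT INTO {table} ({fields}) VALUES ").format(
        table=sql.Identifier("exports_test_schema", "test_table"),
        fields=sql.SQL(", ").join(map(sql.Identifier, columns)),
    )
    template = sql.SQL("({})").format(sql.SQL(", ").join(map(sql.Placeholder, columns)))

    # A client-side cursor is required so that mogrify() can interpolate values in Python.
    connection.cursor_factory = psycopg.AsyncClientCursor

    async with connection.cursor() as cursor:
        prefix = pre_query.as_string(cursor).encode("utf-8")
        values: list[bytes] = []

        async def flush() -> None:
            # One round trip per batch: INSERT INTO ... VALUES (...), (...), ...
            await cursor.execute(prefix + b",".join(values))

        for row in rows:
            values.append(cursor.mogrify(template, row).encode("utf-8"))
            if len(values) >= batch_size:
                await flush()
                values.clear()

        if values:
            await flush()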