Skip to content

Commit

Permalink
refactor: Postgres (+ Redshift) batch exports now async (#18631)
Browse files Browse the repository at this point in the history
* refactor: Postgres batch exports now async

* fix: Let's use from psycopg import sql instead

* test: Update Redshift tests

* fix: Typing issues

* fix: Main insert batch loop
  • Loading branch information
tomasfarias authored Nov 15, 2023
1 parent 3735e44 commit 70970ff
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 225 deletions.
70 changes: 70 additions & 0 deletions posthog/temporal/tests/batch_exports/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import psycopg
import pytest
import pytest_asyncio
from psycopg import sql


@pytest.fixture
Expand Down Expand Up @@ -40,3 +42,71 @@ async def truncate_events(clickhouse_client):
"""
yield
await clickhouse_client.execute_query("TRUNCATE TABLE IF EXISTS `sharded_events`")


async def _autocommit_connection(postgres_config, dbname=None):
    """Open an async, autocommit connection to the configured Postgres server.

    Autocommit is required because CREATE DATABASE and DROP DATABASE cannot
    run inside a transaction block. When `dbname` is None we connect to the
    server's default database (used to create/drop the test database).
    """
    extra = {"dbname": dbname} if dbname is not None else {}
    connection = await psycopg.AsyncConnection.connect(
        user=postgres_config["user"],
        password=postgres_config["password"],
        host=postgres_config["host"],
        port=postgres_config["port"],
        **extra,
    )
    await connection.set_autocommit(True)
    return connection


@pytest_asyncio.fixture
async def setup_postgres_test_db(postgres_config):
    """Fixture to manage a database for Postgres and Redshift export testing.

    Managing a test database involves the following steps:
    1. Creating a test database (only if it doesn't already exist).
    2. Initializing a connection to that database.
    3. Creating a test schema.
    4. Yielding control to the tests.
    5. After tests, dropping the test schema and any tables in it.
    6. Dropping the test database.

    Each connection is closed in a `finally` block so a failure during setup,
    the test itself, or teardown doesn't leak server connections.
    """
    connection = await _autocommit_connection(postgres_config)
    try:
        async with connection.cursor() as cursor:
            await cursor.execute(
                sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"),
                (postgres_config["database"],),
            )

            if await cursor.fetchone() is None:
                await cursor.execute(
                    sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"]))
                )
    finally:
        await connection.close()

    # We need a new connection to connect to the database we just created.
    connection = await _autocommit_connection(postgres_config, dbname=postgres_config["database"])
    try:
        async with connection.cursor() as cursor:
            await cursor.execute(
                sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"]))
            )

        yield

        # Teardown: CASCADE removes any tables the tests created in the schema.
        async with connection.cursor() as cursor:
            await cursor.execute(
                sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"]))
            )
    finally:
        await connection.close()

    # We need a new connection to drop the database, as we cannot drop the current database.
    connection = await _autocommit_connection(postgres_config)
    try:
        async with connection.cursor() as cursor:
            await cursor.execute(
                sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))
            )
    finally:
        await connection.close()
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from random import randint
from uuid import uuid4

import psycopg2
import psycopg
import pytest
import pytest_asyncio
from django.conf import settings
from django.test import override_settings
from psycopg2 import sql
from psycopg import sql
from temporalio import activity
from temporalio.client import WorkflowFailureError
from temporalio.common import RetryPolicy
Expand All @@ -35,15 +35,17 @@
]


def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None):
async def assert_events_in_postgres(connection, schema, table_name, events, exclude_events: list[str] | None = None):
"""Assert provided events written to a given Postgres table."""
inserted_events = []

with connection.cursor() as cursor:
cursor.execute(sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name)))
async with connection.cursor() as cursor:
await cursor.execute(
sql.SQL("SELECT * FROM {} ORDER BY event, timestamp").format(sql.Identifier(schema, table_name))
)
columns = [column.name for column in cursor.description]

for row in cursor.fetchall():
for row in await cursor.fetchall():
event = dict(zip(columns, row))
event["timestamp"] = dt.datetime.fromisoformat(event["timestamp"].isoformat())
inserted_events.append(event)
Expand All @@ -61,12 +63,12 @@ def assert_events_in_postgres(connection, schema, table_name, events, exclude_ev
"distinct_id": event.get("distinct_id"),
"elements": json.dumps(elements_chain),
"event": event.get("event"),
"ip": properties.get("$ip", None) if properties else None,
"ip": properties.get("$ip", "") if properties else "",
"properties": event.get("properties"),
"set": properties.get("$set", None) if properties else None,
"set_once": properties.get("$set_once", None) if properties else None,
# Kept for backwards compatibility, but not exported anymore.
"site_url": None,
"site_url": "",
# For compatibility with CH which doesn't parse timezone component, so we add it here assuming UTC.
"timestamp": dt.datetime.fromisoformat(event.get("timestamp") + "+00:00"),
"team_id": event.get("team_id"),
Expand Down Expand Up @@ -94,75 +96,19 @@ def postgres_config():
}


@pytest.fixture
def setup_test_db(postgres_config):
    """Create a dedicated test database and schema; drop both after the tests.

    Setup: connect to the server, create the configured database if missing,
    reconnect into it, and create the configured schema. Teardown: drop the
    schema (CASCADE) and then, from a fresh server connection, drop the
    database itself.
    """
    server_params = {
        "user": postgres_config["user"],
        "password": postgres_config["password"],
        "host": postgres_config["host"],
        "port": postgres_config["port"],
    }

    # Connect without a specific database so we can create the test one.
    # Autocommit is needed: CREATE/DROP DATABASE can't run in a transaction.
    conn = psycopg2.connect(**server_params)
    conn.set_session(autocommit=True)

    with conn.cursor() as cursor:
        cursor.execute(
            sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"),
            (postgres_config["database"],),
        )
        already_exists = cursor.fetchone() is not None
        if not already_exists:
            cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"])))

    conn.close()

    # A fresh connection is required to work inside the newly created database.
    conn = psycopg2.connect(database=postgres_config["database"], **server_params)
    conn.set_session(autocommit=True)

    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"])))

    yield

    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"])))

    conn.close()

    # Dropping the current database isn't allowed, so reconnect to the server first.
    conn = psycopg2.connect(**server_params)
    conn.set_session(autocommit=True)

    with conn.cursor() as cursor:
        cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"])))

    conn.close()


@pytest.fixture
def postgres_connection(postgres_config, setup_test_db):
connection = psycopg2.connect(
@pytest_asyncio.fixture
async def postgres_connection(postgres_config, setup_postgres_test_db):
connection = await psycopg.AsyncConnection.connect(
user=postgres_config["user"],
password=postgres_config["password"],
database=postgres_config["database"],
dbname=postgres_config["database"],
host=postgres_config["host"],
port=postgres_config["port"],
)

yield connection

connection.close()
await connection.close()


@pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
Expand Down Expand Up @@ -241,7 +187,7 @@ async def test_insert_into_postgres_activity_inserts_data_into_postgres_table(
with override_settings(BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2):
await activity_environment.run(insert_into_postgres_activity, insert_inputs)

assert_events_in_postgres(
await assert_events_in_postgres(
connection=postgres_connection,
schema=postgres_config["schema"],
table_name="test_table",
Expand Down Expand Up @@ -329,7 +275,7 @@ async def test_postgres_export_workflow(
inputs = PostgresBatchExportInputs(
team_id=ateam.pk,
batch_export_id=str(postgres_batch_export.id),
data_interval_end="2023-04-25 14:30:00.000000",
data_interval_end=data_interval_end.isoformat(),
interval=interval,
**postgres_batch_export.destination.config,
)
Expand Down Expand Up @@ -362,7 +308,7 @@ async def test_postgres_export_workflow(
run = runs[0]
assert run.status == "Completed"

assert_events_in_postgres(
await assert_events_in_postgres(
postgres_connection,
postgres_config["schema"],
table_name,
Expand Down
Loading

0 comments on commit 70970ff

Please sign in to comment.