From fe5eb46c95189893bf41085cbee401e4a8bc61dd Mon Sep 17 00:00:00 2001
From: eric
Date: Thu, 25 Jan 2024 19:28:20 -0500
Subject: [PATCH] fix test

---
 .../pipelines/postgres/__init__.py            | 36 +---------
 .../pipelines/postgres/helpers.py             |  2 +
 .../temporal/tests/external_data/conftest.py  | 71 +++++++++++++++++++
 .../test_external_data_job.py                 | 56 ++++++++++++---
 .../api/test/test_external_data_source.py     |  2 +-
 .../warehouse/models/external_data_schema.py  |  2 +
 6 files changed, 124 insertions(+), 45 deletions(-)
 create mode 100644 posthog/temporal/tests/external_data/conftest.py
 rename posthog/temporal/tests/{ => external_data}/test_external_data_job.py (92%)

diff --git a/posthog/temporal/data_imports/pipelines/postgres/__init__.py b/posthog/temporal/data_imports/pipelines/postgres/__init__.py
index d515322e9fd33..438b25fbe9dac 100644
--- a/posthog/temporal/data_imports/pipelines/postgres/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/postgres/__init__.py
@@ -1,7 +1,7 @@
 """Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads."""
 from typing import List, Optional, Union, Iterable, Any
 
-from sqlalchemy import MetaData, Table
+from sqlalchemy import MetaData, Table, text
 from sqlalchemy.engine import Engine
 
 import dlt
@@ -70,37 +70,3 @@ def sql_database(
             primary_key=get_primary_key(table),
             spec=SqlDatabaseTableConfiguration,
         )(engine, table)
-
-
-@dlt.common.configuration.with_config(sections=("sources", "sql_database"), spec=SqlTableResourceConfiguration)
-def sql_table(
-    credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value,
-    table: str = dlt.config.value,
-    schema: Optional[str] = dlt.config.value,
-    metadata: Optional[MetaData] = None,
-    incremental: Optional[dlt.sources.incremental[Any]] = None,
-) -> DltResource:
-    """
-    A dlt resource which loads data from an SQL database table using SQLAlchemy.
-
-    Args:
-        credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `Engine` instance representing the database connection.
-        table (str): Name of the table to load.
-        schema (Optional[str]): Optional name of the schema the table belongs to.
-        metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. If provided, the `schema` argument is ignored.
-        incremental (Optional[dlt.sources.incremental[Any]]): Option to enable incremental loading for the table.
-            E.g., `incremental=dlt.sources.incremental('updated_at', pendulum.parse('2022-01-01T00:00:00Z'))`
-        write_disposition (str): Write disposition of the resource.
-
-    Returns:
-        DltResource: The dlt resource for loading data from the SQL database table.
-    """
-    engine = engine_from_credentials(credentials)
-    engine.execution_options(stream_results=True)
-    metadata = metadata or MetaData(schema=schema)
-
-    table_obj = Table(table, metadata, autoload_with=engine)
-
-    return dlt.resource(table_rows, name=table_obj.name, primary_key=get_primary_key(table_obj))(
-        engine, table_obj, incremental=incremental
-    )
diff --git a/posthog/temporal/data_imports/pipelines/postgres/helpers.py b/posthog/temporal/data_imports/pipelines/postgres/helpers.py
index ac23217a96ccc..7d45a6df7e302 100644
--- a/posthog/temporal/data_imports/pipelines/postgres/helpers.py
+++ b/posthog/temporal/data_imports/pipelines/postgres/helpers.py
@@ -93,6 +93,8 @@ def table_rows(
     loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size)
     yield from loader.load_rows()
 
+    engine.dispose()
+
 
 def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engine, str]) -> Engine:
     if isinstance(credentials, Engine):
diff --git a/posthog/temporal/tests/external_data/conftest.py b/posthog/temporal/tests/external_data/conftest.py
new file mode 100644
index 0000000000000..1d2fbcf47b8f0
--- /dev/null
+++ b/posthog/temporal/tests/external_data/conftest.py
@@ -0,0 +1,71 @@
+import psycopg
+import pytest_asyncio
+from psycopg import sql
+
+
+@pytest_asyncio.fixture
+async def setup_postgres_test_db(postgres_config):
+    """Fixture to manage a database for external data source testing.
+
+    Managing a test database involves the following steps:
+    1. Creating a test database.
+    2. Initializing a connection to that database.
+    3. Creating a test schema.
+    4. Yielding the connection to be used in tests.
+    5. After tests, drop the test schema and any tables in it.
+    6. Drop the test database.
+    """
+    connection = await psycopg.AsyncConnection.connect(
+        user=postgres_config["user"],
+        password=postgres_config["password"],
+        host=postgres_config["host"],
+        port=postgres_config["port"],
+    )
+    await connection.set_autocommit(True)
+
+    async with connection.cursor() as cursor:
+        await cursor.execute(
+            sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"),
+            (postgres_config["database"],),
+        )
+
+        if await cursor.fetchone() is None:
+            await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"])))
+
+    await connection.close()
+
+    # We need a new connection to connect to the database we just created.
+    connection = await psycopg.AsyncConnection.connect(
+        user=postgres_config["user"],
+        password=postgres_config["password"],
+        host=postgres_config["host"],
+        port=postgres_config["port"],
+        dbname=postgres_config["database"],
+    )
+    await connection.set_autocommit(True)
+
+    async with connection.cursor() as cursor:
+        await cursor.execute(
+            sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"]))
+        )
+
+    yield
+
+    async with connection.cursor() as cursor:
+        await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"])))
+
+    await connection.close()
+
+    # We need a new connection to drop the database, as we cannot drop the current database.
+    connection = await psycopg.AsyncConnection.connect(
+        user=postgres_config["user"],
+        password=postgres_config["password"],
+        host=postgres_config["host"],
+        port=postgres_config["port"],
+    )
+    await connection.set_autocommit(True)
+
+    async with connection.cursor() as cursor:
+        await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"])))
+
+    await connection.close()
diff --git a/posthog/temporal/tests/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py
similarity index 92%
rename from posthog/temporal/tests/test_external_data_job.py
rename to posthog/temporal/tests/external_data/test_external_data_job.py
index ad17b1dfc48c0..a4574502ff9d4 100644
--- a/posthog/temporal/tests/test_external_data_job.py
+++ b/posthog/temporal/tests/external_data/test_external_data_job.py
@@ -41,6 +41,7 @@
 import functools
 from django.conf import settings
 import asyncio
+import psycopg
 
 BUCKET_NAME = "test-external-data-jobs"
 SESSION = aioboto3.Session()
@@ -89,6 +90,33 @@ async def minio_client(bucket_name):
     await minio_client.delete_bucket(Bucket=bucket_name)
 
 
+@pytest.fixture
+def postgres_config():
+    return {
+        "user": settings.PG_USER,
+        "password": settings.PG_PASSWORD,
+        "database": "external_data_database",
+        "schema": "external_data_schema",
+        "host": settings.PG_HOST,
+        "port": int(settings.PG_PORT),
+    }
+
+
+@pytest_asyncio.fixture
+async def postgres_connection(postgres_config, setup_postgres_test_db):
+    connection = await psycopg.AsyncConnection.connect(
+        user=postgres_config["user"],
+        password=postgres_config["password"],
+        dbname=postgres_config["database"],
+        host=postgres_config["host"],
+        port=postgres_config["port"],
+    )
+
+    yield connection
+
+    await connection.close()
+
+
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.asyncio
 async def test_create_external_job_activity(activity_environment, team, **kwargs):
@@ -493,7 +521,17 @@ async def mock_async_func(inputs):
 
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.asyncio
-async def test_run_postgres_job(activity_environment, team, minio_client, **kwargs):
+async def test_run_postgres_job(
+    activity_environment, team, minio_client, postgres_connection, postgres_config, **kwargs
+):
+    await postgres_connection.execute(
+        "CREATE TABLE IF NOT EXISTS {schema}.posthog_test (id integer)".format(schema=postgres_config["schema"])
+    )
+    await postgres_connection.execute(
+        "INSERT INTO {schema}.posthog_test (id) VALUES (1)".format(schema=postgres_config["schema"])
+    )
+    await postgres_connection.commit()
+
     async def setup_job_1():
         new_source = await sync_to_async(ExternalDataSource.objects.create)(
             source_id=uuid.uuid4(),
@@ -503,12 +541,12 @@ async def setup_job_1():
             status="running",
             source_type="Postgres",
             job_inputs={
-                "host": settings.PG_HOST,
-                "port": int(settings.PG_PORT),
-                "database": settings.PG_DATABASE,
-                "user": settings.PG_USER,
-                "password": settings.PG_PASSWORD,
-                "schema": "public",
+                "host": postgres_config["host"],
+                "port": postgres_config["port"],
+                "database": postgres_config["database"],
+                "user": postgres_config["user"],
+                "password": postgres_config["password"],
+                "schema": postgres_config["schema"],
             },
         )  # type: ignore
 
@@ -521,7 +559,7 @@ async def setup_job_1():
 
         new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)()  # type: ignore
 
-        schemas = ["posthog_team"]
+        schemas = ["posthog_test"]
         inputs = ExternalDataJobInputs(
             team_id=team.id,
             run_id=new_job.pk,
@@ -543,6 +581,6 @@ async def setup_job_1():
         )
 
     job_1_team_objects = await minio_client.list_objects_v2(
-        Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/posthog_team/"
+        Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/posthog_test/"
     )
     assert len(job_1_team_objects["Contents"]) == 1
diff --git a/posthog/warehouse/api/test/test_external_data_source.py b/posthog/warehouse/api/test/test_external_data_source.py
index 43bce4e965336..d0aed693e7d59 100644
--- a/posthog/warehouse/api/test/test_external_data_source.py
+++ b/posthog/warehouse/api/test/test_external_data_source.py
@@ -146,7 +146,7 @@ def test_database_schema(self):
             data={
                 "host": settings.PG_HOST,
                 "port": int(settings.PG_PORT),
-                "database": settings.PG_DATABASE,
+                "dbname": settings.PG_DATABASE,
                 "user": settings.PG_USER,
                 "password": settings.PG_PASSWORD,
                 "schema": "public",
diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py
index 4f882ff97ed72..8a4ac00e81416 100644
--- a/posthog/warehouse/models/external_data_schema.py
+++ b/posthog/warehouse/models/external_data_schema.py
@@ -65,4 +65,6 @@ def get_postgres_schemas(host: str, port: str, database: str, user: str, passwor
         result = cursor.fetchall()
         result = [row[0] for row in result]
 
+    connection.close()
+
     return result
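
Note on the cleanup calls this patch adds (engine.dispose() in helpers.py and connection.close() in external_data_schema.py): the new conftest.py fixture drops the test database on teardown, and Postgres will not drop a database that still has open connections, which is presumably why these code paths now release their connections explicitly. Below is a minimal standalone sketch of the same pattern, assuming SQLAlchemy; iter_rows, the DSN, and the query are hypothetical names used for illustration and are not part of this patch.

from collections.abc import Iterator
from typing import Any

from sqlalchemy import create_engine, text


def iter_rows(dsn: str, query: str) -> Iterator[dict[str, Any]]:
    # Hypothetical helper mirroring table_rows(): stream rows from a query, then
    # release every pooled connection so a later DROP DATABASE (as in the
    # conftest.py teardown above) is not blocked by lingering sessions.
    engine = create_engine(dsn).execution_options(stream_results=True)
    try:
        with engine.connect() as connection:
            for row in connection.execute(text(query)):
                yield dict(row._mapping)
    finally:
        engine.dispose()

The same idea appears twice in the diff itself: table_rows() disposes of its engine after yielding rows, and get_postgres_schemas() closes its psycopg connection before returning.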