Commit fe5eb46

fix test

EDsCODE committed Jan 26, 2024
1 parent 56c91b9 commit fe5eb46
Showing 6 changed files with 124 additions and 45 deletions.
36 changes: 1 addition & 35 deletions posthog/temporal/data_imports/pipelines/postgres/__init__.py
@@ -1,7 +1,7 @@
"""Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads."""

from typing import List, Optional, Union, Iterable, Any
from sqlalchemy import MetaData, Table
from sqlalchemy import MetaData, Table, text
from sqlalchemy.engine import Engine

import dlt
@@ -70,37 +70,3 @@ def sql_database(
primary_key=get_primary_key(table),
spec=SqlDatabaseTableConfiguration,
)(engine, table)


@dlt.common.configuration.with_config(sections=("sources", "sql_database"), spec=SqlTableResourceConfiguration)
def sql_table(
credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value,
table: str = dlt.config.value,
schema: Optional[str] = dlt.config.value,
metadata: Optional[MetaData] = None,
incremental: Optional[dlt.sources.incremental[Any]] = None,
) -> DltResource:
"""
A dlt resource which loads data from an SQL database table using SQLAlchemy.
Args:
credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `Engine` instance representing the database connection.
table (str): Name of the table to load.
schema (Optional[str]): Optional name of the schema the table belongs to.
metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. If provided, the `schema` argument is ignored.
incremental (Optional[dlt.sources.incremental[Any]]): Option to enable incremental loading for the table.
E.g., `incremental=dlt.sources.incremental('updated_at', pendulum.parse('2022-01-01T00:00:00Z'))`
write_disposition (str): Write disposition of the resource.
Returns:
DltResource: The dlt resource for loading data from the SQL database table.
"""
engine = engine_from_credentials(credentials)
engine.execution_options(stream_results=True)
metadata = metadata or MetaData(schema=schema)

table_obj = Table(table, metadata, autoload_with=engine)

return dlt.resource(table_rows, name=table_obj.name, primary_key=get_primary_key(table_obj))(
engine, table_obj, incremental=incremental
)
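With `sql_table` removed, single-table loads go through the `sql_database` source instead. A minimal sketch of such a call, assuming the parameters of the upstream dlt `sql_database` verified source (`credentials`, `schema`, `table_names`); the pipeline name, destination, and connection string are placeholders, not values from this commit:

    import dlt

    from posthog.temporal.data_imports.pipelines.postgres import sql_database

    # Illustrative wiring; destination and names are placeholders.
    pipeline = dlt.pipeline(
        pipeline_name="postgres_import",
        destination="filesystem",
        dataset_name="external_data",
    )

    # `table_names` narrows the source to just the tables we need.
    source = sql_database(
        credentials="postgresql://user:password@localhost:5432/external_data_database",
        schema="external_data_schema",
        table_names=["posthog_test"],
    )

    pipeline.run(source)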
2 changes: 2 additions & 0 deletions posthog/temporal/data_imports/pipelines/postgres/helpers.py
@@ -93,6 +93,8 @@ def table_rows(
loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size)
yield from loader.load_rows()

engine.dispose()


def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engine, str]) -> Engine:
if isinstance(credentials, Engine):
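The new `engine.dispose()` only executes once `load_rows()` has been fully consumed, at which point it closes every connection in the engine's pool so repeated jobs (and test runs) don't leak connections. A minimal sketch of the same pattern against a bare SQLAlchemy engine; `stream_rows` is a hypothetical helper, not code from this repository:

    from typing import Any, Iterator

    import sqlalchemy as sa


    def stream_rows(url: str, query: str) -> Iterator[Any]:
        """Yield rows lazily, then tear the engine down."""
        engine = sa.create_engine(url)
        with engine.connect() as connection:
            # stream_results avoids buffering the whole result set in memory.
            result = connection.execution_options(stream_results=True).execute(sa.text(query))
            yield from result
        # Reached only after the caller exhausts (or closes) the generator:
        # release every pooled connection, mirroring the dispose() above.
        engine.dispose()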
71 changes: 71 additions & 0 deletions posthog/temporal/tests/external_data/conftest.py
@@ -0,0 +1,71 @@
import psycopg
import pytest_asyncio
from psycopg import sql


@pytest_asyncio.fixture
async def setup_postgres_test_db(postgres_config):
"""Fixture to manage a database for Redshift export testing.
Managing a test database involves the following steps:
1. Creating a test database.
2. Initializing a connection to that database.
3. Creating a test schema.
4. Yielding the connection to be used in tests.
5. After tests, drop the test schema and any tables in it.
6. Drop the test database.
"""
connection = await psycopg.AsyncConnection.connect(
user=postgres_config["user"],
password=postgres_config["password"],
host=postgres_config["host"],
port=postgres_config["port"],
)
await connection.set_autocommit(True)

async with connection.cursor() as cursor:
await cursor.execute(
sql.SQL("SELECT 1 FROM pg_database WHERE datname = %s"),
(postgres_config["database"],),
)

if await cursor.fetchone() is None:
await cursor.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(postgres_config["database"])))

await connection.close()

# We need a new connection to connect to the database we just created.
connection = await psycopg.AsyncConnection.connect(
user=postgres_config["user"],
password=postgres_config["password"],
host=postgres_config["host"],
port=postgres_config["port"],
dbname=postgres_config["database"],
)
await connection.set_autocommit(True)

async with connection.cursor() as cursor:
await cursor.execute(
sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(sql.Identifier(postgres_config["schema"]))
)

yield

async with connection.cursor() as cursor:
await cursor.execute(sql.SQL("DROP SCHEMA {} CASCADE").format(sql.Identifier(postgres_config["schema"])))

await connection.close()

# We need a new connection to drop the database, as we cannot drop the current database.
connection = await psycopg.AsyncConnection.connect(
user=postgres_config["user"],
password=postgres_config["password"],
host=postgres_config["host"],
port=postgres_config["port"],
)
await connection.set_autocommit(True)

async with connection.cursor() as cursor:
await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"])))

await connection.close()
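A hypothetical test showing how the fixture is consumed: requesting `setup_postgres_test_db` guarantees the database and schema exist for the test body and are torn down afterwards. The query below is illustrative:

    import psycopg
    import pytest


    @pytest.mark.asyncio
    async def test_schema_was_created(postgres_config, setup_postgres_test_db):
        connection = await psycopg.AsyncConnection.connect(
            user=postgres_config["user"],
            password=postgres_config["password"],
            host=postgres_config["host"],
            port=postgres_config["port"],
            dbname=postgres_config["database"],
        )
        async with connection.cursor() as cursor:
            await cursor.execute(
                "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s",
                (postgres_config["schema"],),
            )
            assert await cursor.fetchone() is not None
        await connection.close()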
56 changes: 47 additions & 9 deletions posthog/temporal/tests/external_data/test_external_data_job.py
@@ -41,6 +41,7 @@
import functools
from django.conf import settings
import asyncio
import psycopg

BUCKET_NAME = "test-external-data-jobs"
SESSION = aioboto3.Session()
@@ -89,6 +90,33 @@ async def minio_client(bucket_name):
await minio_client.delete_bucket(Bucket=bucket_name)


@pytest.fixture
def postgres_config():
return {
"user": settings.PG_USER,
"password": settings.PG_PASSWORD,
"database": "external_data_database",
"schema": "external_data_schema",
"host": settings.PG_HOST,
"port": int(settings.PG_PORT),
}


@pytest_asyncio.fixture
async def postgres_connection(postgres_config, setup_postgres_test_db):
connection = await psycopg.AsyncConnection.connect(
user=postgres_config["user"],
password=postgres_config["password"],
dbname=postgres_config["database"],
host=postgres_config["host"],
port=postgres_config["port"],
)

yield connection

await connection.close()


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_create_external_job_activity(activity_environment, team, **kwargs):
@@ -493,7 +521,17 @@ async def mock_async_func(inputs):

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_run_postgres_job(activity_environment, team, minio_client, **kwargs):
async def test_run_postgres_job(
activity_environment, team, minio_client, postgres_connection, postgres_config, **kwargs
):
await postgres_connection.execute(
"CREATE TABLE IF NOT EXISTS {schema}.posthog_test (id integer)".format(schema=postgres_config["schema"])
)
await postgres_connection.execute(
"INSERT INTO {schema}.posthog_test (id) VALUES (1)".format(schema=postgres_config["schema"])
)
await postgres_connection.commit()

async def setup_job_1():
new_source = await sync_to_async(ExternalDataSource.objects.create)(
source_id=uuid.uuid4(),
@@ -503,12 +541,12 @@ async def setup_job_1():
status="running",
source_type="Postgres",
job_inputs={
"host": settings.PG_HOST,
"port": int(settings.PG_PORT),
"database": settings.PG_DATABASE,
"user": settings.PG_USER,
"password": settings.PG_PASSWORD,
"schema": "public",
"host": postgres_config["host"],
"port": postgres_config["port"],
"database": postgres_config["database"],
"user": postgres_config["user"],
"password": postgres_config["password"],
"schema": postgres_config["schema"],
},
) # type: ignore

@@ -521,7 +559,7 @@ async def setup_job_1():

new_job = await sync_to_async(ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").get)() # type: ignore

schemas = ["posthog_team"]
schemas = ["posthog_test"]
inputs = ExternalDataJobInputs(
team_id=team.id,
run_id=new_job.pk,
@@ -543,6 +581,6 @@ async def setup_job_1():
)

job_1_team_objects = await minio_client.list_objects_v2(
Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/posthog_team/"
Bucket=BUCKET_NAME, Prefix=f"{job_1.folder_path}/posthog_test/"
)
assert len(job_1_team_objects["Contents"]) == 1
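The test builds its DDL with `str.format`, which is acceptable for a trusted test constant; if the schema name ever came from user input, the `psycopg.sql` composition already used in conftest.py would be the safer route. A sketch of the equivalent statements:

    from psycopg import sql

    schema = "external_data_schema"  # placeholder for postgres_config["schema"]

    create_table = sql.SQL("CREATE TABLE IF NOT EXISTS {}.posthog_test (id integer)").format(
        sql.Identifier(schema)
    )
    insert_row = sql.SQL("INSERT INTO {}.posthog_test (id) VALUES (1)").format(
        sql.Identifier(schema)
    )
    # e.g. `await postgres_connection.execute(create_table)` inside the test.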
2 changes: 1 addition & 1 deletion posthog/warehouse/api/test/test_external_data_source.py
@@ -146,7 +146,7 @@ def test_database_schema(self):
data={
"host": settings.PG_HOST,
"port": int(settings.PG_PORT),
"database": settings.PG_DATABASE,
"dbname": settings.PG_DATABASE,
"user": settings.PG_USER,
"password": settings.PG_PASSWORD,
"schema": "public",
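The `database` → `dbname` rename lines up with psycopg 3's connection API, which the endpoint presumably forwards these values to: keyword arguments map onto libpq parameters, libpq calls this one `dbname`, and the `database` alias that psycopg2 tolerated is rejected as an invalid connection option. A minimal illustration with placeholder values:

    import psycopg

    # Works: `dbname` is the libpq parameter name.
    connection = psycopg.connect(
        host="localhost",
        port=5432,
        dbname="posthog",
        user="posthog",
        password="posthog",
    )
    connection.close()

    # psycopg.connect(database="posthog", ...) is rejected as an
    # invalid connection option, hence the payload fix above.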
2 changes: 2 additions & 0 deletions posthog/warehouse/models/external_data_schema.py
@@ -65,4 +65,6 @@ def get_postgres_schemas(host: str, port: str, database: str, user: str, password:
result = cursor.fetchall()
result = [row[0] for row in result]

connection.close()

return result
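The explicit `connection.close()` fixes the leak on the happy path, but the connection would still dangle if the query raised. A sketch of an equivalent written with context managers, which close the connection on any exit; the schema query shown is an assumption, since the original query sits outside this hunk:

    import psycopg


    def get_postgres_schemas(host: str, port: str, database: str, user: str, password: str) -> list[str]:
        # The connection context manager closes the connection on exit,
        # even if the cursor work below raises.
        with psycopg.connect(host=host, port=port, dbname=database, user=user, password=password) as connection:
            with connection.cursor() as cursor:
                # Assumed query: list non-system schemas.
                cursor.execute(
                    "SELECT schema_name FROM information_schema.schemata"
                    " WHERE schema_name NOT IN ('pg_catalog', 'information_schema')"
                )
                return [row[0] for row in cursor.fetchall()]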
