diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
index c4fbf956f1217..f23f07c5cd767 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -25,6 +25,71 @@
     get_primary_key,
     SqlDatabaseTableConfiguration,
 )
+from dlt.common.data_types.typing import TDataType
+
+POSTGRES_TO_DLT_TYPES: dict[str, TDataType] = {
+    # Text types
+    "char": "text",
+    "character": "text",
+    "varchar": "text",
+    "character varying": "text",
+    "text": "text",
+    "xml": "text",
+    "uuid": "text",
+    "cidr": "text",
+    "inet": "text",
+    "macaddr": "text",
+    "macaddr8": "text",
+    "tsvector": "text",
+    "tsquery": "text",
+    # Bigint types
+    "bigint": "bigint",
+    "bigserial": "bigint",
+    # Boolean type
+    "boolean": "bool",
+    # Timestamp types
+    "timestamp": "timestamp",
+    "timestamp with time zone": "timestamp",
+    "timestamp without time zone": "timestamp",
+    # Complex types (geometric types, ranges, json, etc.)
+    "point": "complex",
+    "line": "complex",
+    "lseg": "complex",
+    "box": "complex",
+    "path": "complex",
+    "polygon": "complex",
+    "circle": "complex",
+    "int4range": "complex",
+    "int8range": "complex",
+    "numrange": "complex",
+    "tsrange": "complex",
+    "tstzrange": "complex",
+    "daterange": "complex",
+    "json": "complex",
+    "jsonb": "complex",
+    # Decimal types
+    "real": "decimal",
+    "double precision": "decimal",
+    "numeric": "decimal",
+    "decimal": "decimal",
+    # Date type
+    "date": "date",
+    # Additional mappings
+    "smallint": "bigint",
+    "integer": "bigint",
+    "serial": "bigint",
+    "money": "decimal",
+    "bytea": "text",
+    "time": "timestamp",
+    "time with time zone": "timestamp",
+    "time without time zone": "timestamp",
+    "interval": "complex",
+    "bit": "text",
+    "bit varying": "text",
+    "enum": "text",
+    "oid": "bigint",
+    "pg_lsn": "text",
+}
 
 
 def incremental_type_to_initial_value(field_type: IncrementalFieldType) -> Any:
@@ -262,7 +327,7 @@ def get_column_hints(engine: Engine, schema_name: str, table_name: str) -> dict[
     with engine.connect() as conn:
         execute_result: CursorResult = conn.execute(
             text(
-                "SELECT column_name, data_type, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = :schema_name AND table_name = :table_name"
+                "SELECT column_name, data_type, numeric_precision, numeric_scale, is_nullable FROM information_schema.columns WHERE table_schema = :schema_name AND table_name = :table_name"
             ),
             {"schema_name": schema_name, "table_name": table_name},
         )
@@ -272,14 +337,19 @@ def get_column_hints(engine: Engine, schema_name: str, table_name: str) -> dict[
 
     columns: dict[str, TColumnSchema] = {}
 
-    for column_name, data_type, numeric_precision, numeric_scale in results:
-        if data_type != "numeric":
-            continue
-
-        columns[column_name] = {
-            "data_type": "decimal",
-            "precision": numeric_precision or 76,
-            "scale": numeric_scale or 32,
-        }
+    for column_name, data_type, numeric_precision, numeric_scale, is_nullable in results:
+        if data_type == "numeric":
+            columns[column_name] = {
+                "data_type": "decimal",
+                "precision": numeric_precision or 76,
+                "scale": numeric_scale or 32,
+                "nullable": is_nullable == "YES",
+            }
+        else:
+            columns[column_name] = {
+                "name": column_name,
+                "data_type": POSTGRES_TO_DLT_TYPES.get(data_type, "text"),
+                "nullable": is_nullable == "YES",
+            }
 
     return columns
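The net effect of the `get_column_hints` change: every column now yields a dlt column schema (with nullability), not just `numeric` ones, and `POSTGRES_TO_DLT_TYPES` is the lookup with `"text"` as the fallback for unmapped Postgres types. A minimal sketch of the new behavior — the sample rows and the trimmed copy of the mapping below are illustrative, not from the codebase:

```python
# Illustrative, trimmed copy of the mapping added in the diff above.
POSTGRES_TO_DLT_TYPES = {"integer": "bigint", "jsonb": "complex"}

# Hypothetical information_schema.columns rows, in the new SELECT order:
# (column_name, data_type, numeric_precision, numeric_scale, is_nullable)
rows = [
    ("id", "integer", None, None, "NO"),
    ("price", "numeric", 10, 2, "YES"),
    ("payload", "jsonb", None, None, "YES"),
    ("geom", "some_custom_type", None, None, "YES"),  # made-up type: falls back to "text"
]

columns = {}
for column_name, data_type, numeric_precision, numeric_scale, is_nullable in rows:
    if data_type == "numeric":
        # numeric keeps its declared precision/scale, defaulting to (76, 32)
        columns[column_name] = {
            "data_type": "decimal",
            "precision": numeric_precision or 76,
            "scale": numeric_scale or 32,
            "nullable": is_nullable == "YES",
        }
    else:
        columns[column_name] = {
            "name": column_name,
            "data_type": POSTGRES_TO_DLT_TYPES.get(data_type, "text"),
            "nullable": is_nullable == "YES",
        }

assert columns["id"] == {"name": "id", "data_type": "bigint", "nullable": False}
assert columns["price"] == {"data_type": "decimal", "precision": 10, "scale": 2, "nullable": True}
assert columns["payload"]["data_type"] == "complex"
assert columns["geom"]["data_type"] == "text"
```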
diff --git a/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py b/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py
index edf217c4a67a4..a136b2ef7af96 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py
@@ -24,22 +24,24 @@ def test_get_column_hints_numeric_no_results():
 
 
 def test_get_column_hints_numeric_with_scale_and_precision():
-    mock_engine = _setup([("column", "numeric", 10, 2)])
+    mock_engine = _setup([("column", "numeric", 10, 2, "NO")])
 
     assert get_column_hints(mock_engine, "some_schema", "some_table") == {
-        "column": {"data_type": "decimal", "precision": 10, "scale": 2}
+        "column": {"data_type": "decimal", "precision": 10, "scale": 2, "nullable": False}
     }
 
 
 def test_get_column_hints_numeric_with_missing_scale_and_precision():
-    mock_engine = _setup([("column", "numeric", None, None)])
+    mock_engine = _setup([("column", "numeric", None, None, "NO")])
 
     assert get_column_hints(mock_engine, "some_schema", "some_table") == {
-        "column": {"data_type": "decimal", "precision": 76, "scale": 32}
+        "column": {"data_type": "decimal", "precision": 76, "scale": 32, "nullable": False}
     }
 
 
 def test_get_column_hints_numeric_with_no_numeric():
-    mock_engine = _setup([("column", "bigint", None, None)])
+    mock_engine = _setup([("column", "bigint", None, None, "NO")])
 
-    assert get_column_hints(mock_engine, "some_schema", "some_table") == {}
+    assert get_column_hints(mock_engine, "some_schema", "some_table") == {
+        "column": {"name": "column", "data_type": "bigint", "nullable": False}
+    }
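For context on where hints like these end up: dlt resources accept per-column schema hints via `apply_hints(columns=...)`. The sketch below is a generic usage example under that assumption, not this pipeline's actual wiring (which the diff doesn't show). It also illustrates why explicit hints matter for the empty-table test added next: an empty table yields no rows to infer types from, so the destination schema must come from the hints.

```python
import dlt

@dlt.resource(name="posthog_test")
def posthog_test_rows():
    # Stand-in data; the real pipeline streams rows from Postgres.
    yield {"id": 1, "price": "3.50"}

# Attach explicit types/nullability so the destination schema isn't
# inferred from the data alone.
posthog_test_rows.apply_hints(
    columns={
        "id": {"name": "id", "data_type": "bigint", "nullable": False},
        "price": {"data_type": "decimal", "precision": 10, "scale": 2, "nullable": True},
    }
)
```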
diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py
index 93630571c3a7a..a1ff942cc7448 100644
--- a/posthog/temporal/tests/external_data/test_external_data_job.py
+++ b/posthog/temporal/tests/external_data/test_external_data_job.py
@@ -809,9 +809,104 @@ def mock_to_object_store_rs_credentials(class_self):
         job_1_team_objects = await minio_client.list_objects_v2(
             Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/"
         )
+
         assert len(job_1_team_objects["Contents"]) == 2
 
 
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_run_postgres_job_empty_table(
+    activity_environment, team, minio_client, postgres_connection, postgres_config, **kwargs
+):
+    await postgres_connection.execute(
+        "CREATE TABLE IF NOT EXISTS {schema}.posthog_test (id integer)".format(schema=postgres_config["schema"])
+    )
+    await postgres_connection.commit()
+
+    async def setup_job_1():
+        new_source = await sync_to_async(ExternalDataSource.objects.create)(
+            source_id=uuid.uuid4(),
+            connection_id=uuid.uuid4(),
+            destination_id=uuid.uuid4(),
+            team=team,
+            status="running",
+            source_type="Postgres",
+            job_inputs={
+                "host": postgres_config["host"],
+                "port": postgres_config["port"],
+                "database": postgres_config["database"],
+                "user": postgres_config["user"],
+                "password": postgres_config["password"],
+                "schema": postgres_config["schema"],
+                "ssh_tunnel_enabled": False,
+            },
+        )
+
+        posthog_test_schema = await _create_schema("posthog_test", new_source, team)
+
+        new_job: ExternalDataJob = await sync_to_async(ExternalDataJob.objects.create)(
+            team_id=team.id,
+            pipeline_id=new_source.pk,
+            status=ExternalDataJob.Status.RUNNING,
+            rows_synced=0,
+            schema=posthog_test_schema,
+        )
+
+        new_job = await sync_to_async(
+            ExternalDataJob.objects.filter(id=new_job.id).prefetch_related("pipeline").prefetch_related("schema").get
+        )()
+
+        inputs = ImportDataActivityInputs(
+            team_id=team.id, run_id=str(new_job.pk), source_id=new_source.pk, schema_id=posthog_test_schema.id
+        )
+
+        return new_job, inputs
+
+    job_1, job_1_inputs = await setup_job_1()
+
+    def mock_to_session_credentials(class_self):
+        return {
+            "aws_access_key_id": settings.OBJECT_STORAGE_ACCESS_KEY_ID,
+            "aws_secret_access_key": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
+            "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
+            "aws_session_token": None,
+            "AWS_ALLOW_HTTP": "true",
+            "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
+        }
+
+    def mock_to_object_store_rs_credentials(class_self):
+        return {
+            "aws_access_key_id": settings.OBJECT_STORAGE_ACCESS_KEY_ID,
+            "aws_secret_access_key": settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
+            "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
+            "region": "us-east-1",
+            "AWS_ALLOW_HTTP": "true",
+            "AWS_S3_ALLOW_UNSAFE_RENAME": "true",
+        }
+
+    with (
+        override_settings(
+            BUCKET_URL=f"s3://{BUCKET_NAME}",
+            AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID,
+            AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY,
+            AIRBYTE_BUCKET_REGION="us-east-1",
+            AIRBYTE_BUCKET_DOMAIN="objectstorage:19000",
+            BUCKET_NAME=BUCKET_NAME,
+        ),
+        mock.patch.object(AwsCredentials, "to_session_credentials", mock_to_session_credentials),
+        mock.patch.object(AwsCredentials, "to_object_store_rs_credentials", mock_to_object_store_rs_credentials),
+    ):
+        await asyncio.gather(
+            activity_environment.run(import_data_activity, job_1_inputs),
+        )
+
+        folder_path = await sync_to_async(job_1.folder_path)()
+        job_1_team_objects = await minio_client.list_objects_v2(
+            Bucket=BUCKET_NAME, Prefix=f"{folder_path}/posthog_test/"
+        )
+        assert len(job_1_team_objects["Contents"]) == 3
+
+
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.asyncio
 async def test_check_schedule_activity_with_schema_id(activity_environment, team, **kwargs):
diff --git a/requirements.in b/requirements.in
index 4fef89511c686..087277f406285 100644
--- a/requirements.in
+++ b/requirements.in
@@ -50,7 +50,7 @@ langsmith==0.1.106
 lzstring==1.0.4
 natsort==8.4.0
 nanoid==2.0.0
-numpy==1.23.3
+numpy==1.26.0
 openpyxl==3.1.2
 orjson==3.10.7
 pandas==2.2.0
diff --git a/requirements.txt b/requirements.txt
index d7ed441fe2c61..dfadc62f31656 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -381,7 +381,7 @@ natsort==8.4.0
     # via -r requirements.in
 nh3==0.2.14
     # via -r requirements.in
-numpy==1.23.3
+numpy==1.26.0
     # via
     #   -r requirements.in
     #   langchain
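A quick way to sanity-check the new mapping against whatever dlt release the bumped requirements resolve to: in the dlt versions that still use `"complex"`, `TDataType` is a `typing.Literal`, so its members can be enumerated with `typing.get_args`. An illustrative check under that assumption:

```python
from typing import get_args

from dlt.common.data_types.typing import TDataType

from posthog.temporal.data_imports.pipelines.sql_database import POSTGRES_TO_DLT_TYPES

# Every value in the mapping must be a member of dlt's TDataType Literal;
# an unknown data_type string would fail dlt's schema validation downstream.
valid_types = set(get_args(TDataType))
unknown = set(POSTGRES_TO_DLT_TYPES.values()) - valid_types
assert not unknown, f"values not in TDataType: {unknown}"
```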