diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
index 3438db67a941f..700e3af65b99e 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -1,14 +1,15 @@
 """Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads."""
 from datetime import datetime, date
-from typing import Any, Optional, Union, List  # noqa: UP035
+from typing import Any, Optional, Union, List, cast  # noqa: UP035
 from collections.abc import Iterable
 from zoneinfo import ZoneInfo
 
 from sqlalchemy import MetaData, Table
-from sqlalchemy.engine import Engine
+from sqlalchemy.engine import Engine, CursorResult
 
 import dlt
 from dlt.sources import DltResource, DltSource
+from dlt.common.schema.typing import TColumnSchema
 
 from dlt.sources.credentials import ConnectionStringCredentials
 
@@ -139,8 +140,37 @@ def sql_database(
             write_disposition="merge" if incremental else "replace",
             spec=SqlDatabaseTableConfiguration,
             table_format="delta",
+            columns=get_column_hints(engine, schema or "", table.name),
         )(
             engine=engine,
             table=table,
             incremental=incremental,
         )
+
+
+def get_column_hints(engine: Engine, schema_name: str, table_name: str) -> dict[str, TColumnSchema]:
+    with engine.connect() as conn:
+        execute_result: CursorResult | None = conn.execute(
+            "SELECT column_name, data_type, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = %(schema_name)s AND table_name = %(table_name)s",
+            {"schema_name": schema_name, "table_name": table_name},
+        )
+
+        if execute_result is None:
+            return {}
+
+        cursor_result = cast(CursorResult, execute_result)
+        results = cursor_result.fetchall()
+
+    columns: dict[str, TColumnSchema] = {}
+
+    for column_name, data_type, numeric_precision, numeric_scale in results:
+        if data_type != "numeric":
+            continue
+
+        columns[column_name] = {
+            "data_type": "decimal",
+            "precision": numeric_precision or 76,
+            "scale": numeric_scale or 16,
+        }
+
+    return columns
diff --git a/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py b/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py
new file mode 100644
index 0000000000000..d604d1e38c35b
--- /dev/null
+++ b/posthog/temporal/data_imports/pipelines/sql_database/test/test_sql_database.py
@@ -0,0 +1,45 @@
+from unittest.mock import MagicMock
+
+from posthog.temporal.data_imports.pipelines.sql_database import get_column_hints
+
+
+def _setup(return_value):
+    mock_engine = MagicMock()
+    mock_engine_enter = MagicMock()
+    mock_connection = MagicMock()
+    mock_result = MagicMock()
+
+    mock_engine.configure_mock(**{"connect.return_value": mock_engine_enter})
+    mock_engine_enter.configure_mock(**{"__enter__.return_value": mock_connection})
+    mock_connection.configure_mock(**{"execute.return_value": mock_result})
+    mock_result.configure_mock(**{"fetchall.return_value": return_value})
+
+    return mock_engine
+
+
+def test_get_column_hints_numeric_no_results():
+    mock_engine = _setup([])
+
+    assert get_column_hints(mock_engine, "some_schema", "some_table") == {}
+
+
+def test_get_column_hints_numeric_with_scale_and_precision():
+    mock_engine = _setup([("column", "numeric", 10, 2)])
+
+    assert get_column_hints(mock_engine, "some_schema", "some_table") == {
+        "column": {"data_type": "decimal", "precision": 10, "scale": 2}
+    }
+
+
+def test_get_column_hints_numeric_with_missing_scale_and_precision():
+    mock_engine = _setup([("column", "numeric", None, None)])
+
+    assert get_column_hints(mock_engine, "some_schema", "some_table") == {
+        "column": {"data_type": "decimal", "precision": 76, "scale": 16}
+    }
+
+
+def test_get_column_hints_numeric_with_no_numeric():
+    mock_engine = _setup([("column", "bigint", None, None)])
+
+    assert get_column_hints(mock_engine, "some_schema", "some_table") == {}
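
For reviewers, a minimal standalone sketch of the mapping the new `get_column_hints` helper applies to rows read from `information_schema.columns`. The sample rows below are invented for illustration and mirror the cases covered by the tests; this is not part of the diff itself.

```python
# Illustrative sketch only: the per-row transformation performed by get_column_hints.
# Each row is (column_name, data_type, numeric_precision, numeric_scale).
sample_rows = [
    ("amount", "numeric", 10, 2),        # constrained numeric(10, 2)
    ("balance", "numeric", None, None),  # unconstrained numeric
    ("id", "bigint", None, None),        # non-numeric, skipped
]

hints = {}
for column_name, data_type, numeric_precision, numeric_scale in sample_rows:
    if data_type != "numeric":
        continue
    # Missing precision/scale fall back to the defaults used in the diff (76 / 16).
    hints[column_name] = {
        "data_type": "decimal",
        "precision": numeric_precision or 76,
        "scale": numeric_scale or 16,
    }

print(hints)
# {'amount': {'data_type': 'decimal', 'precision': 10, 'scale': 2},
#  'balance': {'data_type': 'decimal', 'precision': 76, 'scale': 16}}
```

The resulting dict is passed to dlt through the `columns=` argument added to the `dlt.resource(...)` call in the first hunk, so numeric columns are loaded as decimals with the detected (or fallback) precision and scale.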