Skip to content

Commit

Permalink
fix(data-warehouse): Ensure we're using the correct precision/scale n…
Browse files Browse the repository at this point in the history
…umeric values (#23913)
  • Loading branch information
Gilbert09 authored Jul 23, 2024
1 parent 5dd0781 commit fc835c6
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
34 changes: 32 additions & 2 deletions posthog/temporal/data_imports/pipelines/sql_database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads."""

from datetime import datetime, date
from typing import Any, Optional, Union, List # noqa: UP035
from typing import Any, Optional, Union, List, cast # noqa: UP035
from collections.abc import Iterable
from zoneinfo import ZoneInfo
from sqlalchemy import MetaData, Table
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Engine, CursorResult

import dlt
from dlt.sources import DltResource, DltSource
from dlt.common.schema.typing import TColumnSchema


from dlt.sources.credentials import ConnectionStringCredentials
Expand Down Expand Up @@ -139,8 +140,37 @@ def sql_database(
write_disposition="merge" if incremental else "replace",
spec=SqlDatabaseTableConfiguration,
table_format="delta",
columns=get_column_hints(engine, schema or "", table.name),
)(
engine=engine,
table=table,
incremental=incremental,
)


def get_column_hints(engine: Engine, schema_name: str, table_name: str) -> dict[str, TColumnSchema]:
with engine.connect() as conn:
execute_result: CursorResult | None = conn.execute(
"SELECT column_name, data_type, numeric_precision, numeric_scale FROM information_schema.columns WHERE table_schema = %(schema_name)s AND table_name = %(table_name)s",
{"schema_name": schema_name, "table_name": table_name},
)

if execute_result is None:
return {}

cursor_result = cast(CursorResult, execute_result)
results = cursor_result.fetchall()

columns: dict[str, TColumnSchema] = {}

for column_name, data_type, numeric_precision, numeric_scale in results:
if data_type != "numeric":
continue

columns[column_name] = {
"data_type": "decimal",
"precision": numeric_precision or 76,
"scale": numeric_scale or 16,
}

return columns
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import MagicMock

from posthog.temporal.data_imports.pipelines.sql_database import get_column_hints


def _setup(return_value):
mock_engine = MagicMock()
mock_engine_enter = MagicMock()
mock_connection = MagicMock()
mock_result = MagicMock()

mock_engine.configure_mock(**{"connect.return_value": mock_engine_enter})
mock_engine_enter.configure_mock(**{"__enter__.return_value": mock_connection})
mock_connection.configure_mock(**{"execute.return_value": mock_result})
mock_result.configure_mock(**{"fetchall.return_value": return_value})

return mock_engine


def test_get_column_hints_numeric_no_results():
mock_engine = _setup([])

assert get_column_hints(mock_engine, "some_schema", "some_table") == {}


def test_get_column_hints_numeric_with_scale_and_precision():
mock_engine = _setup([("column", "numeric", 10, 2)])

assert get_column_hints(mock_engine, "some_schema", "some_table") == {
"column": {"data_type": "decimal", "precision": 10, "scale": 2}
}


def test_get_column_hints_numeric_with_missing_scale_and_precision():
mock_engine = _setup([("column", "numeric", None, None)])

assert get_column_hints(mock_engine, "some_schema", "some_table") == {
"column": {"data_type": "decimal", "precision": 76, "scale": 16}
}


def test_get_column_hints_numeric_with_no_numeric():
mock_engine = _setup([("column", "bigint", None, None)])

assert get_column_hints(mock_engine, "some_schema", "some_table") == {}

0 comments on commit fc835c6

Please sign in to comment.