chore(data-warehouse): add connect_args to sql import (#25637)
EDsCODE authored Oct 17, 2024
1 parent 423f893 commit f645a4b
Showing 2 changed files with 22 additions and 3 deletions.
15 changes: 14 additions & 1 deletion posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -65,6 +65,8 @@ def sql_source_for_type(
     else:
         incremental = None

+    connect_args = []
+
     if source_type == ExternalDataSource.Type.POSTGRES:
         credentials = ConnectionStringCredentials(
             f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}"
@@ -76,6 +78,10 @@
         credentials = ConnectionStringCredentials(
             f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?ssl_ca={ssl_ca}&ssl_verify_cert=false"
         )
+
+        # PlanetScale needs this to be set
+        if host.endswith("psdb.cloud"):
+            connect_args = ["SET workload = 'OLAP';"]
     elif source_type == ExternalDataSource.Type.MSSQL:
         credentials = ConnectionStringCredentials(
             f"mssql+pyodbc://{user}:{password}@{host}:{port}/{database}?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
@@ -84,7 +90,12 @@
         raise Exception("Unsupported source_type")

     db_source = sql_database(
-        credentials, schema=schema, table_names=table_names, incremental=incremental, team_id=team_id
+        credentials,
+        schema=schema,
+        table_names=table_names,
+        incremental=incremental,
+        team_id=team_id,
+        connect_args=connect_args,
     )

     return db_source
@@ -180,6 +191,7 @@ def sql_database(
     table_names: Optional[List[str]] = dlt.config.value,  # noqa: UP006
     incremental: Optional[dlt.sources.incremental] = None,
     team_id: Optional[int] = None,
+    connect_args: Optional[list[str]] = None,
 ) -> Iterable[DltResource]:
     """
     A DLT source which loads data from an SQL database using SQLAlchemy.
@@ -231,6 +243,7 @@
                 engine=engine,
                 table=table,
                 incremental=incremental,
+                connect_args=connect_args,
             )
         )

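For readers skimming the diff: the __init__.py change attaches a session-setup statement to MySQL sources hosted on PlanetScale and threads it through sql_database. A standalone sketch of that selection logic follows; the helper name and example hosts are illustrative only, not code from this commit.

# Illustrative sketch only -- mirrors the host check added above, not part of the commit.
def planetscale_connect_args(host: str) -> list[str]:
    # PlanetScale-hosted MySQL endpoints end in "psdb.cloud"; the commit sends
    # "SET workload = 'OLAP';" on such connections ("PlanetScale needs this to be set").
    if host.endswith("psdb.cloud"):
        return ["SET workload = 'OLAP';"]
    return []

print(planetscale_connect_args("us-east.connect.psdb.cloud"))  # ["SET workload = 'OLAP';"]
print(planetscale_connect_args("mysql.internal.example.com"))  # []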
10 changes: 8 additions & 2 deletions posthog/temporal/data_imports/pipelines/sql_database/helpers.py
@@ -14,7 +14,7 @@
 from dlt.common.typing import TDataItem
 from .settings import DEFAULT_CHUNK_SIZE

-from sqlalchemy import Table, create_engine, Column
+from sqlalchemy import Table, create_engine, Column, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.sql import Select

@@ -26,11 +26,13 @@ def __init__(
         table: Table,
         chunk_size: int = 1000,
         incremental: Optional[dlt.sources.incremental[Any]] = None,
+        connect_args: Optional[list[str]] = None,
     ) -> None:
         self.engine = engine
         self.table = table
         self.chunk_size = chunk_size
         self.incremental = incremental
+        self.connect_args = connect_args
         if incremental:
             try:
                 self.cursor_column: Optional[Column[Any]] = table.c[incremental.cursor_path]
@@ -74,6 +76,9 @@ def make_query(self) -> Select[Any]:
     def load_rows(self) -> Iterator[list[TDataItem]]:
         query = self.make_query()
         with self.engine.connect() as conn:
+            if self.connect_args:
+                for stmt in self.connect_args:
+                    conn.execute(text(stmt))
             result = conn.execution_options(yield_per=self.chunk_size).execute(query)
             for partition in result.partitions(size=self.chunk_size):
                 yield [dict(row._mapping) for row in partition]
@@ -84,6 +89,7 @@ def table_rows(
     table: Table,
     chunk_size: int = DEFAULT_CHUNK_SIZE,
     incremental: Optional[dlt.sources.incremental[Any]] = None,
+    connect_args: Optional[list[str]] = None,
 ) -> Iterator[TDataItem]:
     """
     A DLT source which loads data from an SQL database using SQLAlchemy.
@@ -100,7 +106,7 @@
     """
     yield dlt.mark.materialize_table_schema()  # type: ignore

-    loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size)
+    loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size, connect_args=connect_args)
     yield from loader.load_rows()

     engine.dispose()
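The helpers.py side boils down to: run each connect_args statement on the same connection that will stream the table, before the main SELECT. Below is a self-contained sketch of that pattern; SQLite and the PRAGMA are stand-ins chosen so the snippet runs anywhere, while the commit's real statement is "SET workload = 'OLAP';" against MySQL.

# Sketch of the TableLoader.load_rows() pattern above, using an in-memory SQLite
# database purely for illustration (assumption: any dialect-specific session
# statement could go in connect_args).
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
connect_args = ["PRAGMA case_sensitive_like = ON;"]  # stand-in for "SET workload = 'OLAP';"

with engine.connect() as conn:
    # Session setup first, on the same connection the query will use.
    for stmt in connect_args:
        conn.execute(text(stmt))
    # Then stream the query in chunks, as load_rows() does.
    result = conn.execution_options(yield_per=2).execute(text("SELECT 1 AS x UNION ALL SELECT 2"))
    for partition in result.partitions(size=2):
        print([dict(row._mapping) for row in partition])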
