diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
index 96bfa8a9d202d..962cbb2d4ad9b 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -65,6 +65,8 @@ def sql_source_for_type(
     else:
         incremental = None
 
+    connect_args = []
+
     if source_type == ExternalDataSource.Type.POSTGRES:
         credentials = ConnectionStringCredentials(
             f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}"
         )
@@ -76,6 +78,10 @@
         credentials = ConnectionStringCredentials(
             f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?ssl_ca={ssl_ca}&ssl_verify_cert=false"
         )
+
+        # PlanetScale needs this to be set
+        if host.endswith("psdb.cloud"):
+            connect_args = ["SET workload = 'OLAP';"]
     elif source_type == ExternalDataSource.Type.MSSQL:
         credentials = ConnectionStringCredentials(
             f"mssql+pyodbc://{user}:{password}@{host}:{port}/{database}?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
         )
@@ -84,7 +90,12 @@
         raise Exception("Unsupported source_type")
 
     db_source = sql_database(
-        credentials, schema=schema, table_names=table_names, incremental=incremental, team_id=team_id
+        credentials,
+        schema=schema,
+        table_names=table_names,
+        incremental=incremental,
+        team_id=team_id,
+        connect_args=connect_args,
     )
 
     return db_source
@@ -180,6 +191,7 @@ def sql_database(
     table_names: Optional[List[str]] = dlt.config.value,  # noqa: UP006
     incremental: Optional[dlt.sources.incremental] = None,
     team_id: Optional[int] = None,
+    connect_args: Optional[list[str]] = None,
 ) -> Iterable[DltResource]:
     """
     A DLT source which loads data from an SQL database using SQLAlchemy.
@@ -231,6 +243,7 @@ def sql_database(
                 engine=engine,
                 table=table,
                 incremental=incremental,
+                connect_args=connect_args,
             )
         )
 
diff --git a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py
index d877effb3e374..50577b6b04d17 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py
@@ -14,7 +14,7 @@
 from dlt.common.typing import TDataItem
 from .settings import DEFAULT_CHUNK_SIZE
 
-from sqlalchemy import Table, create_engine, Column
+from sqlalchemy import Table, create_engine, Column, text
 from sqlalchemy.engine import Engine
 from sqlalchemy.sql import Select
@@ -26,11 +26,13 @@ def __init__(
         self,
         engine: Engine,
         table: Table,
         chunk_size: int = 1000,
         incremental: Optional[dlt.sources.incremental[Any]] = None,
+        connect_args: Optional[list[str]] = None,
     ) -> None:
         self.engine = engine
         self.table = table
         self.chunk_size = chunk_size
         self.incremental = incremental
+        self.connect_args = connect_args
         if incremental:
             try:
                 self.cursor_column: Optional[Column[Any]] = table.c[incremental.cursor_path]
@@ -74,6 +76,9 @@ def make_query(self) -> Select[Any]:
     def load_rows(self) -> Iterator[list[TDataItem]]:
         query = self.make_query()
         with self.engine.connect() as conn:
+            if self.connect_args:
+                for stmt in self.connect_args:
+                    conn.execute(text(stmt))
             result = conn.execution_options(yield_per=self.chunk_size).execute(query)
             for partition in result.partitions(size=self.chunk_size):
                 yield [dict(row._mapping) for row in partition]
@@ -84,6 +89,7 @@ def table_rows(
     table: Table,
     chunk_size: int = DEFAULT_CHUNK_SIZE,
     incremental: Optional[dlt.sources.incremental[Any]] = None,
+    connect_args: Optional[list[str]] = None,
 ) -> Iterator[TDataItem]:
     """
     A DLT source which loads data from an SQL database using SQLAlchemy.
@@ -100,7 +106,7 @@ def table_rows(
     """
     yield dlt.mark.materialize_table_schema()  # type: ignore
 
-    loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size)
+    loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size, connect_args=connect_args)
     yield from loader.load_rows()
 
     engine.dispose()
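Note on naming: despite the name, `connect_args` in this diff is a list of raw SQL statements that `TableLoader.load_rows` executes once per connection before the table query runs — it is not the `connect_args` dict that SQLAlchemy's `create_engine` accepts. Below is a minimal standalone sketch of that runtime behaviour; the connection URL, table name, and chunk size are illustrative placeholders, not values from this PR.

from sqlalchemy import create_engine, text

# Placeholder URL; PlanetScale hosts end in "psdb.cloud", which is what
# the diff keys off to enable the OLAP workload.
engine = create_engine("mysql+pymysql://user:pass@example.psdb.cloud/exampledb")
session_setup = ["SET workload = 'OLAP';"]  # what connect_args carries

with engine.connect() as conn:
    # Run each setup statement on this connection first, so follow-up
    # reads run under PlanetScale's OLAP workload, which permits long
    # streaming queries that OLTP mode would cut off.
    for stmt in session_setup:
        conn.execute(text(stmt))
    # Then stream the table query in chunks, as TableLoader.load_rows does.
    result = conn.execution_options(yield_per=1000).execute(text("SELECT * FROM example_table"))
    for partition in result.partitions(size=1000):
        rows = [dict(row._mapping) for row in partition]

Executing the statements on the pulled connection (rather than via engine-level options) keeps the change scoped to the data-import read path: any other use of the same engine is unaffected.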