diff --git a/frontend/public/services/mysql.png b/frontend/public/services/mysql.png
new file mode 100644
index 00000000000000..923f456fd3fbdb
Binary files /dev/null and b/frontend/public/services/mysql.png differ
diff --git a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx
index 8aad2919fc5160..49d93f36b4b850 100644
--- a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx
+++ b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx
@@ -198,6 +198,130 @@ export const SOURCE_DETAILS: Record = {
             },
         ],
     },
+    MySQL: {
+        name: 'MySQL',
+        caption: (
+            <>
+                Enter your MySQL/MariaDB credentials to automatically pull your MySQL data into the PostHog Data
+                warehouse.
+            </>
+        ),
+        fields: [
+            {
+                name: 'host',
+                label: 'Host',
+                type: 'text',
+                required: true,
+                placeholder: 'localhost',
+            },
+            {
+                name: 'port',
+                label: 'Port',
+                type: 'number',
+                required: true,
+                placeholder: '3306',
+            },
+            {
+                name: 'dbname',
+                label: 'Database',
+                type: 'text',
+                required: true,
+                placeholder: 'mysql',
+            },
+            {
+                name: 'user',
+                label: 'User',
+                type: 'text',
+                required: true,
+                placeholder: 'mysql',
+            },
+            {
+                name: 'password',
+                label: 'Password',
+                type: 'password',
+                required: true,
+                placeholder: '',
+            },
+            {
+                name: 'schema',
+                label: 'Schema',
+                type: 'text',
+                required: true,
+                placeholder: 'public',
+            },
+            {
+                name: 'ssh-tunnel',
+                label: 'Use SSH tunnel?',
+                type: 'switch-group',
+                default: false,
+                fields: [
+                    {
+                        name: 'host',
+                        label: 'Tunnel host',
+                        type: 'text',
+                        required: true,
+                        placeholder: 'localhost',
+                    },
+                    {
+                        name: 'port',
+                        label: 'Tunnel port',
+                        type: 'number',
+                        required: true,
+                        placeholder: '22',
+                    },
+                    {
+                        type: 'select',
+                        name: 'auth_type',
+                        label: 'Authentication type',
+                        required: true,
+                        defaultValue: 'password',
+                        options: [
+                            {
+                                label: 'Password',
+                                value: 'password',
+                                fields: [
+                                    {
+                                        name: 'username',
+                                        label: 'Tunnel username',
+                                        type: 'text',
+                                        required: true,
+                                        placeholder: 'User1',
+                                    },
+                                    {
+                                        name: 'password',
+                                        label: 'Tunnel password',
+                                        type: 'password',
+                                        required: true,
+                                        placeholder: '',
+                                    },
+                                ],
+                            },
+                            {
+                                label: 'Key pair',
+                                value: 'keypair',
+                                fields: [
+                                    {
+                                        name: 'private_key',
+                                        label: 'Tunnel private key',
+                                        type: 'textarea',
+                                        required: true,
+                                        placeholder: '',
+                                    },
+                                    {
+                                        name: 'passphrase',
+                                        label: 'Tunnel passphrase',
+                                        type: 'password',
+                                        required: false,
+                                        placeholder: '',
+                                    },
+                                ],
+                            },
+                        ],
+                    },
+                ],
+            },
+        ],
+    },
     Snowflake: {
         name: 'Snowflake',
         caption: (
diff --git a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx
index 51153885c7bbbc..0c97d1bd088f84 100644
--- a/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx
+++ b/frontend/src/scenes/data-warehouse/settings/DataWarehouseManagedSourcesTable.tsx
@@ -10,6 +10,7 @@ import Iconazure from 'public/services/azure.png'
 import IconCloudflare from 'public/services/cloudflare.png'
 import IconGoogleCloudStorage from 'public/services/google-cloud-storage.png'
 import IconHubspot from 'public/services/hubspot.png'
+import IconMySQL from 'public/services/mysql.png'
 import IconPostgres from 'public/services/postgres.png'
 import IconSnowflake from 'public/services/snowflake.png'
 import IconStripe from 'public/services/stripe.png'
@@ -187,6 +188,7 @@ export function RenderDataWarehouseSourceIcon({
         Hubspot: IconHubspot,
         Zendesk: IconZendesk,
         Postgres: IconPostgres,
+        MySQL: IconMySQL,
         Snowflake: IconSnowflake,
         aws: IconAwsS3,
         'google-cloud': IconGoogleCloudStorage,
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index ceede7bd84355e..b6cefa0036249d 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -3828,7 +3828,7 @@ export enum DataWarehouseSettingsTab {
     SelfManaged = 'self-managed',
 }
 
-export const externalDataSources = ['Stripe', 'Hubspot', 'Postgres', 'Zendesk', 'Snowflake'] as const
+export const externalDataSources = ['Stripe', 'Hubspot', 'Postgres', 'MySQL', 'Zendesk', 'Snowflake'] as const
 
 export type ExternalDataSourceType = (typeof externalDataSources)[number]
 
diff --git a/latest_migrations.manifest b/latest_migrations.manifest
index d9d360ae6dc946..a446e72c619689 100644
--- a/latest_migrations.manifest
+++ b/latest_migrations.manifest
@@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name
 ee: 0016_rolemembership_organization_member
 otp_static: 0002_throttling
 otp_totp: 0002_auto_20190420_0723
-posthog: 0447_alter_integration_kind
+posthog: 0448_add_mysql_externaldatasource_source_type
 sessions: 0001_initial
 social_django: 0010_uid_db_index
 two_factor: 0007_auto_20201201_1019
diff --git a/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py b/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py
new file mode 100644
index 00000000000000..b1cc746856069a
--- /dev/null
+++ b/posthog/migrations/0448_add_mysql_externaldatasource_source_type.py
@@ -0,0 +1,27 @@
+# Generated by Django 4.2.11 on 2024-06-05 17:12
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("posthog", "0447_alter_integration_kind"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="externaldatasource",
+            name="source_type",
+            field=models.CharField(
+                choices=[
+                    ("Stripe", "Stripe"),
+                    ("Hubspot", "Hubspot"),
+                    ("Postgres", "Postgres"),
+                    ("Zendesk", "Zendesk"),
+                    ("Snowflake", "Snowflake"),
+                    ("MySQL", "MySQL"),
+                ],
+                max_length=128,
+            ),
+        ),
+    ]
diff --git a/posthog/temporal/data_imports/pipelines/schemas.py b/posthog/temporal/data_imports/pipelines/schemas.py
index 7dccb65eca59b9..8c0355b34d6edc 100644
--- a/posthog/temporal/data_imports/pipelines/schemas.py
+++ b/posthog/temporal/data_imports/pipelines/schemas.py
@@ -21,6 +21,7 @@
     ),
     ExternalDataSource.Type.POSTGRES: (),
     ExternalDataSource.Type.SNOWFLAKE: (),
+    ExternalDataSource.Type.MYSQL: (),
 }
 
 PIPELINE_TYPE_INCREMENTAL_ENDPOINTS_MAPPING = {
@@ -29,6 +30,7 @@
     ExternalDataSource.Type.ZENDESK: ZENDESK_INCREMENTAL_ENDPOINTS,
     ExternalDataSource.Type.POSTGRES: (),
     ExternalDataSource.Type.SNOWFLAKE: (),
+    ExternalDataSource.Type.MYSQL: (),
 }
 
 PIPELINE_TYPE_INCREMENTAL_FIELDS_MAPPING: dict[ExternalDataSource.Type, dict[str, list[IncrementalField]]] = {
@@ -37,4 +39,5 @@
     ExternalDataSource.Type.ZENDESK: ZENDESK_INCREMENTAL_FIELDS,
     ExternalDataSource.Type.POSTGRES: {},
     ExternalDataSource.Type.SNOWFLAKE: {},
+    ExternalDataSource.Type.MYSQL: {},
 }
diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
index 04fb8885701dac..858872fe3ee6ec 100644
--- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
+++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -16,6 +16,7 @@
 from urllib.parse import quote
 
 from posthog.warehouse.types import IncrementalFieldType
+from posthog.warehouse.models.external_data_source import ExternalDataSource
 from sqlalchemy.sql import text
 
 from .helpers import (
@@ -35,7 +36,8 @@ def incremental_type_to_initial_value(field_type: IncrementalFieldType) -> Any:
         return date(1970, 1, 1)
 
 
-def postgres_source(
+def sql_source_for_type(
+    source_type: ExternalDataSource.Type,
     host: str,
     port: int,
     user: str,
@@ -53,10 +55,6 @@ def postgres_source(
     database = quote(database)
     sslmode = quote(sslmode)
 
-    credentials = ConnectionStringCredentials(
-        f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}"
-    )
-
     if incremental_field is not None and incremental_field_type is not None:
         incremental: dlt.sources.incremental | None = dlt.sources.incremental(
             cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type)
@@ -64,6 +62,15 @@
     else:
         incremental = None
 
+    if source_type == ExternalDataSource.Type.POSTGRES:
+        credentials = ConnectionStringCredentials(
+            f"postgresql://{user}:{password}@{host}:{port}/{database}?sslmode={sslmode}"
+        )
+    elif source_type == ExternalDataSource.Type.MYSQL:
+        credentials = ConnectionStringCredentials(f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}")
+    else:
+        raise Exception("Unsupported source_type")
+
     db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental)
 
     return db_source
diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py
index a35bb1667e7b08..21f5e046d1a28e 100644
--- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py
+++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py
@@ -13,7 +13,7 @@
 from posthog.warehouse.models import sync_old_schemas_with_new_schemas, ExternalDataSource, aget_schema_by_id
 from posthog.warehouse.models.external_data_schema import (
     ExternalDataSchema,
-    get_postgres_schemas,
+    get_sql_schemas_for_source_type,
     get_snowflake_schemas,
 )
 from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -46,7 +46,7 @@ async def create_external_data_job_model_activity(inputs: CreateExternalDataJobM
 
     source = await sync_to_async(ExternalDataSource.objects.get)(team_id=inputs.team_id, id=inputs.source_id)
 
-    if source.source_type == ExternalDataSource.Type.POSTGRES:
+    if source.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
         host = source.job_inputs.get("host")
         port = source.job_inputs.get("port")
         user = source.job_inputs.get("user")
@@ -74,8 +74,8 @@ async def create_external_data_job_model_activity(inputs: CreateExternalDataJobM
             private_key=ssh_tunnel_auth_type_private_key,
         )
 
-        schemas_to_sync = await sync_to_async(get_postgres_schemas)(
-            host, port, database, user, password, db_schema, ssh_tunnel
+        schemas_to_sync = await sync_to_async(get_sql_schemas_for_source_type)(
+            source.source_type, host, port, database, user, password, db_schema, ssh_tunnel
         )
     elif source.source_type == ExternalDataSource.Type.SNOWFLAKE:
         account_id = source.job_inputs.get("account_id")
diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py b/posthog/temporal/data_imports/workflow_activities/import_data.py
index 9849339e785c72..190a35e3ab673b 100644
--- a/posthog/temporal/data_imports/workflow_activities/import_data.py
+++ b/posthog/temporal/data_imports/workflow_activities/import_data.py
@@ -102,8 +102,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs):
             schema=schema,
             reset_pipeline=reset_pipeline,
         )
-    elif model.pipeline.source_type == ExternalDataSource.Type.POSTGRES:
-        from posthog.temporal.data_imports.pipelines.sql_database import postgres_source
+    elif model.pipeline.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
+        from posthog.temporal.data_imports.pipelines.sql_database import sql_source_for_type
 
         host = model.pipeline.job_inputs.get("host")
         port = model.pipeline.job_inputs.get("port")
@@ -137,7 +137,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs):
             if tunnel is None:
                 raise Exception("Can't open tunnel to SSH server")
 
-            source = postgres_source(
+            source = sql_source_for_type(
+                source_type=model.pipeline.source_type,
                 host=tunnel.local_bind_host,
                 port=tunnel.local_bind_port,
                 user=user,
@@ -163,7 +164,8 @@ async def import_data_activity(inputs: ImportDataActivityInputs):
                 reset_pipeline=reset_pipeline,
             )
 
-        source = postgres_source(
+        source = sql_source_for_type(
+            source_type=model.pipeline.source_type,
             host=host,
             port=port,
             user=user,
diff --git a/posthog/temporal/tests/batch_exports/test_import_data.py b/posthog/temporal/tests/batch_exports/test_import_data.py
index 2b743102056d0a..935781c3bdf342 100644
--- a/posthog/temporal/tests/batch_exports/test_import_data.py
+++ b/posthog/temporal/tests/batch_exports/test_import_data.py
@@ -70,12 +70,13 @@ async def test_postgres_source_without_ssh_tunnel(activity_environment, team, **
     activity_inputs = await _setup(team, job_inputs)
 
     with (
-        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source,
+        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type,
         mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"),
     ):
         await activity_environment.run(import_data_activity, activity_inputs)
 
-    postgres_source.assert_called_once_with(
+    sql_source_for_type.assert_called_once_with(
+        source_type=ExternalDataSource.Type.POSTGRES,
         host="host.com",
         port="5432",
         user="Username",
@@ -107,12 +108,13 @@ async def test_postgres_source_with_ssh_tunnel_disabled(activity_environment, te
     activity_inputs = await _setup(team, job_inputs)
 
     with (
-        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source,
+        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type,
         mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"),
     ):
         await activity_environment.run(import_data_activity, activity_inputs)
 
-    postgres_source.assert_called_once_with(
+    sql_source_for_type.assert_called_once_with(
+        source_type=ExternalDataSource.Type.POSTGRES,
         host="host.com",
         port="5432",
         user="Username",
@@ -160,13 +162,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
             return MockedTunnel()
 
     with (
-        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.postgres_source") as postgres_source,
+        mock.patch("posthog.temporal.data_imports.pipelines.sql_database.sql_source_for_type") as sql_source_for_type,
         mock.patch("posthog.temporal.data_imports.workflow_activities.import_data._run"),
         mock.patch.object(SSHTunnel, "get_tunnel", mock_get_tunnel),
    ):
         await activity_environment.run(import_data_activity, activity_inputs)
 
-    postgres_source.assert_called_once_with(
+    sql_source_for_type.assert_called_once_with(
+        source_type=ExternalDataSource.Type.POSTGRES,
         host="other-host.com",
         port=55550,
         user="Username",
diff --git a/posthog/warehouse/api/external_data_schema.py b/posthog/warehouse/api/external_data_schema.py
index eaba69507f3927..e85f303b24ccba 100644
--- a/posthog/warehouse/api/external_data_schema.py
+++ b/posthog/warehouse/api/external_data_schema.py
@@ -24,10 +24,11 @@
     delete_data_import_folder,
 )
 from posthog.warehouse.models.external_data_schema import (
+    filter_mysql_incremental_fields,
     filter_postgres_incremental_fields,
     filter_snowflake_incremental_fields,
-    get_postgres_schemas,
     get_snowflake_schemas,
+    get_sql_schemas_for_source_type,
 )
 from posthog.warehouse.models.external_data_source import ExternalDataSource
 from posthog.warehouse.models.ssh_tunnel import SSHTunnel
@@ -253,7 +254,7 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any):
         source: ExternalDataSource = instance.source
         incremental_columns: list[IncrementalField] = []
 
-        if source.source_type == ExternalDataSource.Type.POSTGRES:
+        if source.source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
             # TODO(@Gilbert09): Move all this into a util and replace elsewhere
             host = source.job_inputs.get("host")
             port = source.job_inputs.get("port")
@@ -282,7 +283,8 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any):
                 private_key=ssh_tunnel_auth_type_private_key,
             )
 
-            pg_schemas = get_postgres_schemas(
+            db_schemas = get_sql_schemas_for_source_type(
+                source.source_type,
                 host=host,
                 port=port,
                 database=database,
@@ -292,10 +294,15 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any):
                 ssh_tunnel=ssh_tunnel,
             )
 
-            columns = pg_schemas.get(instance.name, [])
+            columns = db_schemas.get(instance.name, [])
+            if source.source_type == ExternalDataSource.Type.POSTGRES:
+                incremental_fields_func = filter_postgres_incremental_fields
+            else:
+                incremental_fields_func = filter_mysql_incremental_fields
+
             incremental_columns = [
                 {"field": name, "field_type": field_type, "label": name, "type": field_type}
-                for name, field_type in filter_postgres_incremental_fields(columns)
+                for name, field_type in incremental_fields_func(columns)
             ]
         elif source.source_type == ExternalDataSource.Type.SNOWFLAKE:
             # TODO(@Gilbert09): Move all this into a util and replace elsewhere
diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py
index af3df2ec43ee22..059e1fe271154a 100644
--- a/posthog/warehouse/api/external_data_source.py
+++ b/posthog/warehouse/api/external_data_source.py
@@ -34,7 +34,7 @@
 from posthog.warehouse.models.external_data_schema import (
     filter_postgres_incremental_fields,
     filter_snowflake_incremental_fields,
-    get_postgres_schemas,
+    get_sql_schemas_for_source_type,
     get_snowflake_schemas,
 )
 
@@ -50,7 +50,16 @@
 logger = structlog.get_logger(__name__)
 
-GenericPostgresError = "Could not connect to Postgres. Please check all connection details are valid."
+
+def get_generic_sql_error(source_type: ExternalDataSource.Type):
+    if source_type == ExternalDataSource.Type.MYSQL:
+        name = "MySQL"
+    else:
+        name = "Postgres"
+
+    return f"Could not connect to {name}. Please check all connection details are valid."
+
+
 GenericSnowflakeError = "Could not connect to Snowflake. Please check all connection details are valid."
 
 PostgresErrors = {
     "password authentication failed for user": "Invalid user or password",
@@ -248,9 +257,9 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
             new_source_model = self._handle_hubspot_source(request, *args, **kwargs)
         elif source_type == ExternalDataSource.Type.ZENDESK:
             new_source_model = self._handle_zendesk_source(request, *args, **kwargs)
-        elif source_type == ExternalDataSource.Type.POSTGRES:
+        elif source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
             try:
-                new_source_model, postgres_schemas = self._handle_postgres_source(request, *args, **kwargs)
+                new_source_model, sql_schemas = self._handle_sql_source(request, *args, **kwargs)
             except InternalPostgresError:
                 return Response(
                     status=status.HTTP_400_BAD_REQUEST, data={"message": "Cannot use internal Postgres database"}
@@ -264,8 +273,8 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
         payload = request.data["payload"]
         schemas = payload.get("schemas", None)
 
-        if source_type == ExternalDataSource.Type.POSTGRES:
-            default_schemas = postgres_schemas
+        if source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
+            default_schemas = sql_schemas
         elif source_type == ExternalDataSource.Type.SNOWFLAKE:
             default_schemas = snowflake_schemas
         else:
@@ -408,9 +417,7 @@ def _handle_hubspot_source(self, request: Request, *args: Any, **kwargs: Any) ->
 
         return new_source_model
 
-    def _handle_postgres_source(
-        self, request: Request, *args: Any, **kwargs: Any
-    ) -> tuple[ExternalDataSource, list[Any]]:
+    def _handle_sql_source(self, request: Request, *args: Any, **kwargs: Any) -> tuple[ExternalDataSource, list[Any]]:
         payload = request.data["payload"]
         prefix = request.data.get("prefix", None)
         source_type = request.data["source_type"]
@@ -474,7 +481,16 @@ def _handle_postgres_source(
                 private_key=ssh_tunnel_auth_type_private_key,
             )
 
-        schemas = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel)
+        schemas = get_sql_schemas_for_source_type(
+            source_type,
+            host,
+            port,
+            database,
+            user,
+            password,
+            schema,
+            ssh_tunnel,
+        )
 
         return new_source_model, schemas
 
@@ -609,7 +625,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any):
             )
 
         # Get schemas and validate SQL credentials
-        if source_type == ExternalDataSource.Type.POSTGRES:
+        if source_type in [ExternalDataSource.Type.POSTGRES, ExternalDataSource.Type.MYSQL]:
            host = request.data.get("host", None)
            port = request.data.get("port", None)
            database = request.data.get("dbname", None)
@@ -677,11 +693,20 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any):
                 )
 
             try:
-                result = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel)
+                result = get_sql_schemas_for_source_type(
+                    source_type,
+                    host,
+                    port,
+                    database,
+                    user,
+                    password,
+                    schema,
+                    ssh_tunnel,
+                )
                 if len(result.keys()) == 0:
                     return Response(
                         status=status.HTTP_400_BAD_REQUEST,
-                        data={"message": "Postgres schema doesn't exist"},
+                        data={"message": "Schema doesn't exist"},
                     )
             except OperationalError as e:
                 exposed_error = self._expose_postgres_error(e)
@@ -691,12 +716,12 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any):
 
                 return Response(
                     status=status.HTTP_400_BAD_REQUEST,
-                    data={"message": exposed_error or GenericPostgresError},
+                    data={"message": exposed_error or get_generic_sql_error(source_type)},
                 )
             except BaseSSHTunnelForwarderError as e:
                 return Response(
                     status=status.HTTP_400_BAD_REQUEST,
-                    data={"message": e.value or GenericPostgresError},
+                    data={"message": e.value or get_generic_sql_error(source_type)},
                 )
             except Exception as e:
                 capture_exception(e)
@@ -704,7 +729,7 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any):
 
                 return Response(
                     status=status.HTTP_400_BAD_REQUEST,
-                    data={"message": GenericPostgresError},
+                    data={"message": get_generic_sql_error(source_type)},
                 )
 
         filtered_results = [
diff --git a/posthog/warehouse/api/test/test_external_data_source.py b/posthog/warehouse/api/test/test_external_data_source.py
index 84517cafb32af3..2ddadcb0398743 100644
--- a/posthog/warehouse/api/test/test_external_data_source.py
+++ b/posthog/warehouse/api/test/test_external_data_source.py
@@ -584,9 +584,10 @@ def test_database_schema_non_postgres_source(self):
             assert table in table_names
 
     @patch(
-        "posthog.warehouse.api.external_data_source.get_postgres_schemas", return_value={"table_1": [("id", "integer")]}
+        "posthog.warehouse.api.external_data_source.get_sql_schemas_for_source_type",
+        return_value={"table_1": [("id", "integer")]},
     )
-    def test_internal_postgres(self, patch_get_postgres_schemas):
+    def test_internal_postgres(self, patch_get_sql_schemas_for_source_type):
         # This test checks handling of project ID 2 in Cloud US and project ID 1 in Cloud EU,
         # so let's make sure there are no projects with these IDs in the test DB
         Project.objects.filter(id__in=[1, 2]).delete()
diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py
index fbb65500192d78..f47c24277237e2 100644
--- a/posthog/warehouse/models/external_data_schema.py
+++ b/posthog/warehouse/models/external_data_schema.py
@@ -6,6 +6,8 @@
 from posthog.models.utils import CreatedMetaFields, UUIDModel, sane_repr
 import uuid
 import psycopg2
+import pymysql
+from .external_data_source import ExternalDataSource
 from posthog.warehouse.data_load.service import (
     external_data_workflow_exists,
     pause_external_data_schedule,
@@ -222,3 +224,83 @@ def get_schemas(postgres_host: str, postgres_port: int):
             return get_schemas(tunnel.local_bind_host, tunnel.local_bind_port)
 
     return get_schemas(host, int(port))
+
+
+def filter_mysql_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]:
+    results: list[tuple[str, IncrementalFieldType]] = []
+    for column_name, type in columns:
+        type = type.lower()
+        if type.startswith("timestamp"):
+            results.append((column_name, IncrementalFieldType.Timestamp))
+        elif type == "date":
+            results.append((column_name, IncrementalFieldType.Date))
+        elif type == "datetime":
+            results.append((column_name, IncrementalFieldType.DateTime))
+        elif type == "tinyint" or type == "smallint" or type == "mediumint" or type == "int" or type == "bigint":
+            results.append((column_name, IncrementalFieldType.Integer))
+
+    return results
+
+
+def get_mysql_schemas(
+    host: str,
+    port: str,
+    database: str,
+    user: str,
+    password: str,
+    schema: str,
+    ssh_tunnel: SSHTunnel,
+) -> dict[str, list[tuple[str, str]]]:
+    def get_schemas(mysql_host: str, mysql_port: int):
+        connection = pymysql.connect(
+            host=mysql_host,
+            port=mysql_port,
+            database=database,
+            user=user,
+            password=password,
+            connect_timeout=5,
+        )
+
+        with connection.cursor() as cursor:
+            cursor.execute(
+                "SELECT table_name, column_name, data_type FROM information_schema.columns WHERE table_schema = %(schema)s ORDER BY table_name ASC",
+                {"schema": schema},
+            )
+            result = cursor.fetchall()
+
+            schema_list = defaultdict(list)
+            for row in result:
+                schema_list[row[0]].append((row[1], row[2]))
+
+        connection.close()
+
+        return schema_list
+
+    if ssh_tunnel.enabled:
+        with ssh_tunnel.get_tunnel(host, int(port)) as tunnel:
+            if tunnel is None:
+                raise Exception("Can't open tunnel to SSH server")
+
+            return get_schemas(tunnel.local_bind_host, tunnel.local_bind_port)
+
+    return get_schemas(host, int(port))
+
+
+def get_sql_schemas_for_source_type(
+    source_type: ExternalDataSource.Type,
+    host: str,
+    port: str,
+    database: str,
+    user: str,
+    password: str,
+    schema: str,
+    ssh_tunnel: SSHTunnel,
+) -> dict[str, list[tuple[str, str]]]:
+    if source_type == ExternalDataSource.Type.POSTGRES:
+        schemas = get_postgres_schemas(host, port, database, user, password, schema, ssh_tunnel)
+    elif source_type == ExternalDataSource.Type.MYSQL:
+        schemas = get_mysql_schemas(host, port, database, user, password, schema, ssh_tunnel)
+    else:
+        raise Exception("Unsupported source_type")
+
+    return schemas
diff --git a/posthog/warehouse/models/external_data_source.py b/posthog/warehouse/models/external_data_source.py
index dc21af8db26add..f9ffb21d41b298 100644
--- a/posthog/warehouse/models/external_data_source.py
+++ b/posthog/warehouse/models/external_data_source.py
@@ -19,6 +19,7 @@ class Type(models.TextChoices):
         POSTGRES = "Postgres", "Postgres"
         ZENDESK = "Zendesk", "Zendesk"
         SNOWFLAKE = "Snowflake", "Snowflake"
+        MYSQL = "MySQL", "MySQL"
 
     class Status(models.TextChoices):
         RUNNING = "Running", "Running"
diff --git a/requirements-dev.in b/requirements-dev.in
index 9ab0252aecf67f..5ca5431dbaf1c8 100644
--- a/requirements-dev.in
+++ b/requirements-dev.in
@@ -26,6 +26,7 @@ packaging==23.1
 black~=23.9.1
 boto3-stubs[s3]
 types-markdown==3.3.9
+types-PyMySQL==1.1.0.20240524
 types-PyYAML==6.0.1
 types-freezegun==1.1.10
 types-paramiko==3.4.0.20240423
diff --git a/requirements-dev.txt b/requirements-dev.txt
index a528eb65d50a7a..938eaead5395c1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -37,7 +37,7 @@ black==23.9.1
     #   -r requirements-dev.in
     #   datamodel-code-generator
     #   inline-snapshot
-boto3-stubs[s3]==1.34.84
+boto3-stubs==1.34.84
     # via -r requirements-dev.in
 botocore-stubs==1.34.84
     # via boto3-stubs
@@ -62,7 +62,7 @@ click==8.1.7
     #   inline-snapshot
 colorama==0.4.4
     # via pytest-watch
-coverage[toml]==5.5
+coverage==5.5
     # via pytest-cov
 cryptography==39.0.2
     # via
@@ -98,6 +98,7 @@ executing==2.0.1
 faker==17.5.0
     # via -r requirements-dev.in
 fakeredis==2.23.3
+    # via -r requirements-dev.in
 flaky==3.7.0
     # via -r requirements-dev.in
 freezegun==1.2.2
@@ -197,7 +198,7 @@ pycparser==2.20
     # via
     #   -c requirements.txt
     #   cffi
-pydantic[email]==2.5.3
+pydantic==2.5.3
     # via
     #   -c requirements.txt
     #   datamodel-code-generator
@@ -313,6 +314,8 @@ types-markdown==3.3.9
     # via -r requirements-dev.in
 types-paramiko==3.4.0.20240423
     # via -r requirements-dev.in
+types-pymysql==1.1.0.20240524
+    # via -r requirements-dev.in
 types-python-dateutil==2.8.3
     # via -r requirements-dev.in
 types-pytz==2023.3.0.0
diff --git a/requirements.in b/requirements.in
index 9d5e15a4e59a57..03656f814bcf84 100644
--- a/requirements.in
+++ b/requirements.in
@@ -56,6 +56,7 @@ paramiko==3.4.0
 Pillow==10.2.0
 posthoganalytics==3.5.0
 psycopg2-binary==2.9.7
+PyMySQL==1.1.1
 psycopg[binary]==3.1.18
 pyarrow==15.0.0
 pydantic==2.5.3
diff --git a/requirements.txt b/requirements.txt
index 92b14db6600f4d..a7f9721004c84e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -414,6 +414,7 @@ psycopg-binary==3.1.18
     # via psycopg
 psycopg2-binary==2.9.7
     # via -r requirements.in
+PyMySQL==1.1.1
 py==1.11.0
     # via retry
 pyarrow==15.0.0
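
A minimal usage sketch (not part of the patch) of the `sql_source_for_type` helper introduced above, wired up for a MySQL source. The keyword names follow the function body and the updated test assertions in this diff; the connection values and the table name are purely illustrative.

```python
from posthog.temporal.data_imports.pipelines.sql_database import sql_source_for_type
from posthog.warehouse.models.external_data_source import ExternalDataSource

# Illustrative values only; in the Temporal activity these come from model.pipeline.job_inputs.
source = sql_source_for_type(
    source_type=ExternalDataSource.Type.MYSQL,  # selects the mysql+pymysql connection-string branch
    host="127.0.0.1",
    port=3306,
    user="mysql",
    password="example-password",
    database="mysql",
    sslmode="prefer",  # only used when building the Postgres DSN
    schema="public",
    table_names=["customers"],  # hypothetical table
    incremental_field=None,  # full refresh; pass a column name and IncrementalFieldType for incremental syncs
    incremental_field_type=None,
)
```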