From 013ea270b160d7913b12d46fb228c7dbfc901f70 Mon Sep 17 00:00:00 2001 From: Tom Owers Date: Wed, 4 Dec 2024 18:01:54 +0000 Subject: [PATCH] Allow key pair auth for snowflake sources --- .../data-warehouse/new/sourceWizardLogic.tsx | 64 ++++++++++++++--- .../pipelines/sql_database/__init__.py | 60 ++++++++++++---- .../pipelines/sql_database_v2/__init__.py | 60 ++++++++++++---- .../workflow_activities/import_data_sync.py | 15 ++-- .../workflow_activities/sync_new_schemas.py | 21 +++++- posthog/warehouse/api/external_data_source.py | 71 +++++++++++++++---- .../warehouse/models/external_data_schema.py | 40 +++++++++-- 7 files changed, 273 insertions(+), 58 deletions(-) diff --git a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx index a924b0ba594b7..f8e1a5a131205 100644 --- a/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx +++ b/frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx @@ -513,18 +513,60 @@ export const SOURCE_DETAILS: Record = { placeholder: 'COMPUTE_WAREHOUSE', }, { - name: 'user', - label: 'User', - type: 'text', - required: true, - placeholder: 'user', - }, - { - name: 'password', - label: 'Password', - type: 'password', + type: 'select', + name: 'auth_type', + label: 'Authentication type', required: true, - placeholder: '', + defaultValue: 'password', + options: [ + { + label: 'Password', + value: 'password', + fields: [ + { + name: 'username', + label: 'Username', + type: 'text', + required: true, + placeholder: 'User1', + }, + { + name: 'password', + label: 'Password', + type: 'password', + required: true, + placeholder: '', + }, + ], + }, + { + label: 'Key pair', + value: 'keypair', + fields: [ + { + name: 'username', + label: 'Username', + type: 'text', + required: true, + placeholder: 'User1', + }, + { + name: 'private_key', + label: 'Private key', + type: 'textarea', + required: true, + placeholder: '', + }, + { + name: 'passphrase', + label: 'Passphrase', + type: 'password', + required: false, + placeholder: '', + }, + ], + }, + ], }, { name: 'role', diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py index fbf6a0ee01683..04aa7c9678c0b 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py @@ -21,6 +21,9 @@ from posthog.warehouse.models.external_data_source import ExternalDataSource from sqlalchemy.sql import text +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + from .helpers import ( table_rows, engine_from_credentials, @@ -109,8 +112,11 @@ def sql_source_for_type( def snowflake_source( account_id: str, - user: str, - password: str, + user: Optional[str], + password: Optional[str], + passphrase: Optional[str], + private_key: Optional[str], + auth_type: str, database: str, warehouse: str, schema: str, @@ -119,13 +125,6 @@ def snowflake_source( incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: - account_id = quote(account_id) - user = quote(user) - password = quote(password) - database = quote(database) - warehouse = quote(warehouse) - role = quote(role) if role else None - if incremental_field is not None and incremental_field_type is not None: incremental: dlt.sources.incremental | None = dlt.sources.incremental( cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type) @@ -133,9 +132,46 @@ def snowflake_source( else: incremental = None - credentials = ConnectionStringCredentials( - f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" - ) + if auth_type == "password" and user is not None and password is not None: + account_id = quote(account_id) + user = quote(user) + password = quote(password) + database = quote(database) + warehouse = quote(warehouse) + role = quote(role) if role else None + + credentials = create_engine( + f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" + ) + else: + assert private_key is not None + assert user is not None + + account_id = quote(account_id) + user = quote(user) + database = quote(database) + warehouse = quote(warehouse) + role = quote(role) if role else None + + p_key = serialization.load_pem_private_key( + private_key.encode("utf-8"), + password=passphrase.encode() if passphrase is not None else None, + backend=default_backend(), + ) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + credentials = create_engine( + f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}", + connect_args={ + "private_key": pkb, + }, + ) + db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) return db_source diff --git a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py index cfd5f4cb822c0..bcab4c3e19282 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py @@ -20,6 +20,9 @@ from posthog.warehouse.models import ExternalDataSource from posthog.warehouse.types import IncrementalFieldType +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization + from .helpers import ( SelectAny, table_rows, @@ -125,8 +128,11 @@ def sql_source_for_type( def snowflake_source( account_id: str, - user: str, - password: str, + user: Optional[str], + password: Optional[str], + passphrase: Optional[str], + private_key: Optional[str], + auth_type: str, database: str, warehouse: str, schema: str, @@ -135,13 +141,6 @@ def snowflake_source( incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: - account_id = quote(account_id) - user = quote(user) - password = quote(password) - database = quote(database) - warehouse = quote(warehouse) - role = quote(role) if role else None - if incremental_field is not None and incremental_field_type is not None: incremental: dlt.sources.incremental | None = dlt.sources.incremental( cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type) @@ -149,9 +148,46 @@ def snowflake_source( else: incremental = None - credentials = ConnectionStringCredentials( - f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" - ) + if auth_type == "password" and user is not None and password is not None: + account_id = quote(account_id) + user = quote(user) + password = quote(password) + database = quote(database) + warehouse = quote(warehouse) + role = quote(role) if role else None + + credentials = create_engine( + f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" + ) + else: + assert private_key is not None + assert user is not None + + account_id = quote(account_id) + user = quote(user) + database = quote(database) + warehouse = quote(warehouse) + role = quote(role) if role else None + + p_key = serialization.load_pem_private_key( + private_key.encode("utf-8"), + password=passphrase.encode() if passphrase is not None else None, + backend=default_backend(), + ) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + credentials = create_engine( + f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}", + connect_args={ + "private_key": pkb, + }, + ) + db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) return db_source diff --git a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py index 091828d34ba52..eeaf3b9f65de1 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py @@ -228,17 +228,24 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): ) account_id = model.pipeline.job_inputs.get("account_id") - user = model.pipeline.job_inputs.get("user") - password = model.pipeline.job_inputs.get("password") database = model.pipeline.job_inputs.get("database") warehouse = model.pipeline.job_inputs.get("warehouse") sf_schema = model.pipeline.job_inputs.get("schema") role = model.pipeline.job_inputs.get("role") + auth_type = model.pipeline.job_inputs.get("auth_type", "password") + auth_type_username = model.pipeline.job_inputs.get("user") + auth_type_password = model.pipeline.job_inputs.get("password") + auth_type_passphrase = model.pipeline.job_inputs.get("passphrase") + auth_type_private_key = model.pipeline.job_inputs.get("private_key") + source = snowflake_source( account_id=account_id, - user=user, - password=password, + auth_type=auth_type, + user=auth_type_username, + password=auth_type_password, + private_key=auth_type_private_key, + passphrase=auth_type_passphrase, database=database, schema=sf_schema, warehouse=warehouse, diff --git a/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py b/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py index 67f8c820e2837..b63d7ea869e16 100644 --- a/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py +++ b/posthog/temporal/data_imports/workflow_activities/sync_new_schemas.py @@ -86,14 +86,29 @@ def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None: return account_id = source.job_inputs.get("account_id") - user = source.job_inputs.get("user") - password = source.job_inputs.get("password") database = source.job_inputs.get("database") warehouse = source.job_inputs.get("warehouse") sf_schema = source.job_inputs.get("schema") role = source.job_inputs.get("role") - sql_schemas = get_snowflake_schemas(account_id, database, warehouse, user, password, sf_schema, role) + auth_type = source.job_inputs.get("auth_type", "password") + auth_type_username = source.job_inputs.get("user") + auth_type_password = source.job_inputs.get("password") + auth_type_passphrase = source.job_inputs.get("passphrase") + auth_type_private_key = source.job_inputs.get("private_key") + + sql_schemas = get_snowflake_schemas( + account_id=account_id, + database=database, + warehouse=warehouse, + user=auth_type_username, + password=auth_type_password, + schema=sf_schema, + role=role, + auth_type=auth_type, + passphrase=auth_type_passphrase, + private_key=auth_type_private_key, + ) schemas_to_sync = list(sql_schemas.keys()) else: diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py index cc1f38170d327..448c06533bf19 100644 --- a/posthog/warehouse/api/external_data_source.py +++ b/posthog/warehouse/api/external_data_source.py @@ -640,10 +640,15 @@ def _handle_snowflake_source( database = payload.get("database") warehouse = payload.get("warehouse") role = payload.get("role") - user = payload.get("user") - password = payload.get("password") schema = payload.get("schema") + auth_type_obj = payload.get("auth_type", {}) + auth_type = auth_type_obj.get("selection", None) + auth_type_username = auth_type_obj.get("username", None) + auth_type_password = auth_type_obj.get("password", None) + auth_type_passphrase = auth_type_obj.get("passphrase", None) + auth_type_private_key = auth_type_obj.get("private_key", None) + new_source_model = ExternalDataSource.objects.create( source_id=str(uuid.uuid4()), connection_id=str(uuid.uuid4()), @@ -656,14 +661,28 @@ def _handle_snowflake_source( "database": database, "warehouse": warehouse, "role": role, - "user": user, - "password": password, "schema": schema, + "auth_type": auth_type, + "user": auth_type_username, + "password": auth_type_password, + "passphrase": auth_type_passphrase, + "private_key": auth_type_private_key, }, prefix=prefix, ) - schemas = get_snowflake_schemas(account_id, database, warehouse, user, password, schema, role) + schemas = get_snowflake_schemas( + account_id=account_id, + database=database, + warehouse=warehouse, + user=auth_type_username, + password=auth_type_password, + schema=schema, + role=role, + passphrase=auth_type_passphrase, + private_key=auth_type_private_key, + auth_type=auth_type, + ) return new_source_model, list(schemas.keys()) @@ -1068,20 +1087,48 @@ def database_schema(self, request: Request, *arg: Any, **kwargs: Any): database = request.data.get("database") warehouse = request.data.get("warehouse") role = request.data.get("role") - user = request.data.get("user") - password = request.data.get("password") schema = request.data.get("schema") - if not account_id or not warehouse or not database or not user or not password or not schema: + auth_type_obj = request.data.get("auth_type", {}) + auth_type = auth_type_obj.get("selection", None) + auth_type_username = auth_type_obj.get("username", None) + auth_type_password = auth_type_obj.get("password", None) + auth_type_passphrase = auth_type_obj.get("passphrase", None) + auth_type_private_key = auth_type_obj.get("private_key", None) + + if not account_id or not warehouse or not database or not schema: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "Missing required parameters: account id, warehouse, database, schema"}, + ) + + if auth_type == "password" and (not auth_type_username or not auth_type_password): + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": "Missing required parameters: username, password"}, + ) + + if auth_type == "keypair" and ( + not auth_type_passphrase or not auth_type_private_key or not auth_type_username + ): return Response( status=status.HTTP_400_BAD_REQUEST, - data={ - "message": "Missing required parameters: account id, warehouse, database, user, password, schema" - }, + data={"message": "Missing required parameters: passphrase, private key"}, ) try: - result = get_snowflake_schemas(account_id, database, warehouse, user, password, schema, role) + result = get_snowflake_schemas( + account_id=account_id, + database=database, + warehouse=warehouse, + user=auth_type_username, + password=auth_type_password, + schema=schema, + role=role, + passphrase=auth_type_passphrase, + private_key=auth_type_private_key, + auth_type=auth_type, + ) if len(result.keys()) == 0: return Response( status=status.HTTP_400_BAD_REQUEST, diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index b8219a6628c02..3b081d54f8501 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -1,5 +1,7 @@ from collections import defaultdict from datetime import datetime, timedelta +import tempfile +import os from typing import Optional from django.db import models from django_deprecate_fields import deprecate_field @@ -193,16 +195,43 @@ def filter_snowflake_incremental_fields(columns: list[tuple[str, str]]) -> list[ def get_snowflake_schemas( - account_id: str, database: str, warehouse: str, user: str, password: str, schema: str, role: Optional[str] = None + account_id: str, + database: str, + warehouse: str, + user: Optional[str], + password: Optional[str], + passphrase: Optional[str], + private_key: Optional[str], + auth_type: str, + schema: str, + role: Optional[str] = None, ) -> dict[str, list[tuple[str, str]]]: + auth_connect_args: dict[str, str | None] = {} + file_name: str | None = None + + if auth_type == "keypair" and private_key is not None: + with tempfile.NamedTemporaryFile(delete=False) as tf: + tf.write(private_key.encode("utf-8")) + file_name = tf.name + + auth_connect_args = { + "user": user, + "private_key_file": file_name, + "private_key_file_pwd": passphrase, + } + else: + auth_connect_args = { + "password": password, + "user": user, + } + with snowflake.connector.connect( - user=user, - password=password, account=account_id, warehouse=warehouse, database=database, schema="information_schema", role=role, + **auth_connect_args, ) as connection: with connection.cursor() as cursor: if cursor is None: @@ -218,7 +247,10 @@ def get_snowflake_schemas( for row in result: schema_list[row[0]].append((row[1], row[2])) - return schema_list + if file_name is not None: + os.unlink(file_name) + + return schema_list def filter_postgres_incremental_fields(columns: list[tuple[str, str]]) -> list[tuple[str, IncrementalFieldType]]: