Allow key pair auth for snowflake sources
Gilbert09 committed Dec 4, 2024
1 parent 8eef695 commit 013ea27
Showing 7 changed files with 273 additions and 58 deletions.
64 changes: 53 additions & 11 deletions frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx
@@ -513,18 +513,60 @@ export const SOURCE_DETAILS: Record<ExternalDataSourceType, SourceConfig> = {
placeholder: 'COMPUTE_WAREHOUSE',
},
{
name: 'user',
label: 'User',
type: 'text',
required: true,
placeholder: 'user',
},
{
name: 'password',
label: 'Password',
type: 'password',
type: 'select',
name: 'auth_type',
label: 'Authentication type',
required: true,
placeholder: '',
defaultValue: 'password',
options: [
{
label: 'Password',
value: 'password',
fields: [
{
name: 'username',
label: 'Username',
type: 'text',
required: true,
placeholder: 'User1',
},
{
name: 'password',
label: 'Password',
type: 'password',
required: true,
placeholder: '',
},
],
},
{
label: 'Key pair',
value: 'keypair',
fields: [
{
name: 'username',
label: 'Username',
type: 'text',
required: true,
placeholder: 'User1',
},
{
name: 'private_key',
label: 'Private key',
type: 'textarea',
required: true,
placeholder: '',
},
{
name: 'passphrase',
label: 'Passphrase',
type: 'password',
required: false,
placeholder: '',
},
],
},
],
},
{
name: 'role',
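The new select field nests a different set of inputs under each authentication option: password auth collects a username and password, while key-pair auth collects a username, a PEM-encoded private key, and an optional passphrase. As a rough sketch (not part of this commit), the job_inputs read by the backend code below would look roughly like this per option; note the wizard field is named username while the backend reads user, so a mapping is assumed to happen in the source create/update path, which this diff does not show.

# Hypothetical payloads; the exact storage of wizard values is an assumption.
password_job_inputs = {
    "auth_type": "password",
    "user": "User1",
    "password": "my-secret-password",
}

keypair_job_inputs = {
    "auth_type": "keypair",
    "user": "User1",
    "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----",
    "passphrase": None,  # optional; only set when the key is encrypted
}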
60 changes: 48 additions & 12 deletions posthog/temporal/data_imports/pipelines/sql_database/__init__.py
@@ -21,6 +21,9 @@
from posthog.warehouse.models.external_data_source import ExternalDataSource
from sqlalchemy.sql import text

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from .helpers import (
table_rows,
engine_from_credentials,
@@ -109,8 +112,11 @@ def sql_source_for_type(

def snowflake_source(
account_id: str,
user: str,
password: str,
user: Optional[str],
password: Optional[str],
passphrase: Optional[str],
private_key: Optional[str],
auth_type: str,
database: str,
warehouse: str,
schema: str,
@@ -119,23 +125,53 @@ def snowflake_source(
incremental_field: Optional[str] = None,
incremental_field_type: Optional[IncrementalFieldType] = None,
) -> DltSource:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

if incremental_field is not None and incremental_field_type is not None:
incremental: dlt.sources.incremental | None = dlt.sources.incremental(
cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type)
)
else:
incremental = None

credentials = ConnectionStringCredentials(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
if auth_type == "password" and user is not None and password is not None:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

credentials = create_engine(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
else:
assert private_key is not None
assert user is not None

account_id = quote(account_id)
user = quote(user)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

p_key = serialization.load_pem_private_key(
private_key.encode("utf-8"),
password=passphrase.encode() if passphrase is not None else None,
backend=default_backend(),
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

credentials = create_engine(
f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}",
connect_args={
"private_key": pkb,
},
)

db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental)

return db_source
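Both pipeline variants handle key-pair auth the same way: the PEM key supplied by the user is loaded with the optional passphrase, re-serialized as unencrypted DER in PKCS#8 format, and passed to the Snowflake SQLAlchemy dialect through connect_args rather than being embedded in the connection URL. A self-contained sketch of that path follows; the helper name is hypothetical, and it assumes snowflake-sqlalchemy is installed so that create_engine understands snowflake:// URLs.

from typing import Optional
from urllib.parse import quote

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from sqlalchemy import create_engine


def snowflake_engine_from_key_pair(
    account_id: str,
    user: str,
    database: str,
    schema: str,
    warehouse: str,
    pem_private_key: str,
    passphrase: Optional[str] = None,
    role: Optional[str] = None,
):
    # Decrypt the PEM key; password must be None for an unencrypted key.
    p_key = serialization.load_pem_private_key(
        pem_private_key.encode("utf-8"),
        password=passphrase.encode() if passphrase else None,
        backend=default_backend(),
    )
    # The Snowflake driver expects the key as unencrypted DER (PKCS#8) bytes.
    pkb = p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    url = (
        f"snowflake://{quote(user)}@{quote(account_id)}/{quote(database)}/{schema}"
        f"?warehouse={quote(warehouse)}{f'&role={quote(role)}' if role else ''}"
    )
    # No password in the URL; the key is handed to the driver directly.
    return create_engine(url, connect_args={"private_key": pkb})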
60 changes: 48 additions & 12 deletions posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py
@@ -20,6 +20,9 @@
from posthog.warehouse.models import ExternalDataSource
from posthog.warehouse.types import IncrementalFieldType

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from .helpers import (
SelectAny,
table_rows,
@@ -125,8 +128,11 @@ def sql_source_for_type(

def snowflake_source(
account_id: str,
user: str,
password: str,
user: Optional[str],
password: Optional[str],
passphrase: Optional[str],
private_key: Optional[str],
auth_type: str,
database: str,
warehouse: str,
schema: str,
@@ -135,23 +141,53 @@ def snowflake_source(
incremental_field: Optional[str] = None,
incremental_field_type: Optional[IncrementalFieldType] = None,
) -> DltSource:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

if incremental_field is not None and incremental_field_type is not None:
incremental: dlt.sources.incremental | None = dlt.sources.incremental(
cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type)
)
else:
incremental = None

credentials = ConnectionStringCredentials(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
if auth_type == "password" and user is not None and password is not None:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

credentials = create_engine(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
else:
assert private_key is not None
assert user is not None

account_id = quote(account_id)
user = quote(user)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

p_key = serialization.load_pem_private_key(
private_key.encode("utf-8"),
password=passphrase.encode() if passphrase is not None else None,
backend=default_backend(),
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

credentials = create_engine(
f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}",
connect_args={
"private_key": pkb,
},
)

db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental)

return db_source
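For reference, a hedged example of calling the updated snowflake_source with key-pair credentials. All values are placeholders, and the parameters hidden in the collapsed part of the signature (such as table_names and role) are assumed to keep their existing shape.

source = snowflake_source(
    account_id="abc12345.eu-west-1",  # placeholder account identifier
    auth_type="keypair",
    user="POSTHOG_SYNC",
    password=None,  # unused for key-pair auth
    private_key="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----",
    passphrase=None,  # only needed when the key is encrypted
    database="ANALYTICS",
    warehouse="COMPUTE_WAREHOUSE",
    schema="PUBLIC",
    table_names=["EVENTS"],
)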
@@ -228,17 +228,24 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
)

account_id = model.pipeline.job_inputs.get("account_id")
user = model.pipeline.job_inputs.get("user")
password = model.pipeline.job_inputs.get("password")
database = model.pipeline.job_inputs.get("database")
warehouse = model.pipeline.job_inputs.get("warehouse")
sf_schema = model.pipeline.job_inputs.get("schema")
role = model.pipeline.job_inputs.get("role")

auth_type = model.pipeline.job_inputs.get("auth_type", "password")
auth_type_username = model.pipeline.job_inputs.get("user")
auth_type_password = model.pipeline.job_inputs.get("password")
auth_type_passphrase = model.pipeline.job_inputs.get("passphrase")
auth_type_private_key = model.pipeline.job_inputs.get("private_key")

source = snowflake_source(
account_id=account_id,
user=user,
password=password,
auth_type=auth_type,
user=auth_type_username,
password=auth_type_password,
private_key=auth_type_private_key,
passphrase=auth_type_passphrase,
database=database,
schema=sf_schema,
warehouse=warehouse,
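The activity defaults auth_type to "password", so Snowflake sources created before this change, whose job_inputs contain no auth_type key, continue to run unchanged. A minimal sketch of that lookup pattern (the helper name is hypothetical):

def resolve_snowflake_auth(job_inputs: dict) -> dict:
    # Older sources never stored auth_type, so fall back to password auth.
    return {
        "auth_type": job_inputs.get("auth_type", "password"),
        "user": job_inputs.get("user"),
        "password": job_inputs.get("password"),
        "passphrase": job_inputs.get("passphrase"),
        "private_key": job_inputs.get("private_key"),
    }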
@@ -86,14 +86,29 @@ def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None:
return

account_id = source.job_inputs.get("account_id")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
warehouse = source.job_inputs.get("warehouse")
sf_schema = source.job_inputs.get("schema")
role = source.job_inputs.get("role")

sql_schemas = get_snowflake_schemas(account_id, database, warehouse, user, password, sf_schema, role)
auth_type = source.job_inputs.get("auth_type", "password")
auth_type_username = source.job_inputs.get("user")
auth_type_password = source.job_inputs.get("password")
auth_type_passphrase = source.job_inputs.get("passphrase")
auth_type_private_key = source.job_inputs.get("private_key")

sql_schemas = get_snowflake_schemas(
account_id=account_id,
database=database,
warehouse=warehouse,
user=auth_type_username,
password=auth_type_password,
schema=sf_schema,
role=role,
auth_type=auth_type,
passphrase=auth_type_passphrase,
private_key=auth_type_private_key,
)

schemas_to_sync = list(sql_schemas.keys())
else:
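get_snowflake_schemas itself is not shown in this diff, so the following is only a sketch of how a helper with this signature might branch on auth_type, assuming it connects with snowflake-connector-python; the argument names come from the call above, everything else is an assumption.

import snowflake.connector
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization


def _snowflake_connection(
    account_id, database, warehouse, user, password, schema, role, auth_type, passphrase, private_key
):
    common = {
        "account": account_id,
        "user": user,
        "database": database,
        "warehouse": warehouse,
        "schema": schema,
        "role": role,
    }
    if auth_type == "keypair" and private_key is not None:
        # Key-pair auth: the connector takes the key as unencrypted DER (PKCS#8) bytes.
        p_key = serialization.load_pem_private_key(
            private_key.encode("utf-8"),
            password=passphrase.encode() if passphrase else None,
            backend=default_backend(),
        )
        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return snowflake.connector.connect(private_key=pkb, **common)
    return snowflake.connector.connect(password=password, **common)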
