Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(data-warehouse): Allow key pair auth for snowflake sources #26652

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 53 additions & 11 deletions frontend/src/scenes/data-warehouse/new/sourceWizardLogic.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,60 @@ export const SOURCE_DETAILS: Record<ExternalDataSourceType, SourceConfig> = {
placeholder: 'COMPUTE_WAREHOUSE',
},
{
name: 'user',
label: 'User',
type: 'text',
required: true,
placeholder: 'user',
},
{
name: 'password',
label: 'Password',
type: 'password',
type: 'select',
name: 'auth_type',
label: 'Authentication type',
required: true,
placeholder: '',
defaultValue: 'password',
options: [
{
label: 'Password',
value: 'password',
fields: [
{
name: 'username',
label: 'Username',
type: 'text',
required: true,
placeholder: 'User1',
},
{
name: 'password',
label: 'Password',
type: 'password',
required: true,
placeholder: '',
},
],
},
{
label: 'Key pair',
value: 'keypair',
fields: [
{
name: 'username',
label: 'Username',
type: 'text',
required: true,
placeholder: 'User1',
},
{
name: 'private_key',
label: 'Private key',
type: 'textarea',
required: true,
placeholder: '',
},
{
name: 'passphrase',
label: 'Passphrase',
type: 'password',
required: false,
placeholder: '',
},
],
},
],
},
{
name: 'role',
Expand Down
238 changes: 119 additions & 119 deletions mypy-baseline.txt

Large diffs are not rendered by default.

60 changes: 48 additions & 12 deletions posthog/temporal/data_imports/pipelines/sql_database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from posthog.warehouse.models.external_data_source import ExternalDataSource
from sqlalchemy.sql import text

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from .helpers import (
table_rows,
engine_from_credentials,
Expand Down Expand Up @@ -109,8 +112,11 @@ def sql_source_for_type(

def snowflake_source(
account_id: str,
user: str,
password: str,
user: Optional[str],
password: Optional[str],
passphrase: Optional[str],
private_key: Optional[str],
auth_type: str,
database: str,
warehouse: str,
schema: str,
Expand All @@ -119,23 +125,53 @@ def snowflake_source(
incremental_field: Optional[str] = None,
incremental_field_type: Optional[IncrementalFieldType] = None,
) -> DltSource:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

if incremental_field is not None and incremental_field_type is not None:
incremental: dlt.sources.incremental | None = dlt.sources.incremental(
cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type)
)
else:
incremental = None

credentials = ConnectionStringCredentials(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
if auth_type == "password" and user is not None and password is not None:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

credentials = create_engine(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
else:
assert private_key is not None
assert user is not None

account_id = quote(account_id)
user = quote(user)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

p_key = serialization.load_pem_private_key(
private_key.encode("utf-8"),
password=passphrase.encode() if passphrase is not None else None,
backend=default_backend(),
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

credentials = create_engine(
f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}",
connect_args={
"private_key": pkb,
},
)

db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental)

return db_source
Expand Down
60 changes: 48 additions & 12 deletions posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from posthog.warehouse.models import ExternalDataSource
from posthog.warehouse.types import IncrementalFieldType

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from .helpers import (
SelectAny,
table_rows,
Expand Down Expand Up @@ -125,8 +128,11 @@ def sql_source_for_type(

def snowflake_source(
account_id: str,
user: str,
password: str,
user: Optional[str],
password: Optional[str],
passphrase: Optional[str],
private_key: Optional[str],
auth_type: str,
database: str,
warehouse: str,
schema: str,
Expand All @@ -135,23 +141,53 @@ def snowflake_source(
incremental_field: Optional[str] = None,
incremental_field_type: Optional[IncrementalFieldType] = None,
) -> DltSource:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

if incremental_field is not None and incremental_field_type is not None:
incremental: dlt.sources.incremental | None = dlt.sources.incremental(
cursor_path=incremental_field, initial_value=incremental_type_to_initial_value(incremental_field_type)
)
else:
incremental = None

credentials = ConnectionStringCredentials(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
if auth_type == "password" and user is not None and password is not None:
account_id = quote(account_id)
user = quote(user)
password = quote(password)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

credentials = create_engine(
f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}"
)
else:
assert private_key is not None
assert user is not None

account_id = quote(account_id)
user = quote(user)
database = quote(database)
warehouse = quote(warehouse)
role = quote(role) if role else None

p_key = serialization.load_pem_private_key(
private_key.encode("utf-8"),
password=passphrase.encode() if passphrase is not None else None,
backend=default_backend(),
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

credentials = create_engine(
f"snowflake://{user}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}",
connect_args={
"private_key": pkb,
},
)

db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental)

return db_source
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,24 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
)

account_id = model.pipeline.job_inputs.get("account_id")
user = model.pipeline.job_inputs.get("user")
password = model.pipeline.job_inputs.get("password")
database = model.pipeline.job_inputs.get("database")
warehouse = model.pipeline.job_inputs.get("warehouse")
sf_schema = model.pipeline.job_inputs.get("schema")
role = model.pipeline.job_inputs.get("role")

auth_type = model.pipeline.job_inputs.get("auth_type", "password")
auth_type_username = model.pipeline.job_inputs.get("user")
auth_type_password = model.pipeline.job_inputs.get("password")
auth_type_passphrase = model.pipeline.job_inputs.get("passphrase")
auth_type_private_key = model.pipeline.job_inputs.get("private_key")

source = snowflake_source(
account_id=account_id,
user=user,
password=password,
auth_type=auth_type,
user=auth_type_username,
password=auth_type_password,
private_key=auth_type_private_key,
passphrase=auth_type_passphrase,
database=database,
schema=sf_schema,
warehouse=warehouse,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,29 @@ def sync_new_schemas_activity(inputs: SyncNewSchemasActivityInputs) -> None:
return

account_id = source.job_inputs.get("account_id")
user = source.job_inputs.get("user")
password = source.job_inputs.get("password")
database = source.job_inputs.get("database")
warehouse = source.job_inputs.get("warehouse")
sf_schema = source.job_inputs.get("schema")
role = source.job_inputs.get("role")

sql_schemas = get_snowflake_schemas(account_id, database, warehouse, user, password, sf_schema, role)
auth_type = source.job_inputs.get("auth_type", "password")
auth_type_username = source.job_inputs.get("user")
auth_type_password = source.job_inputs.get("password")
auth_type_passphrase = source.job_inputs.get("passphrase")
auth_type_private_key = source.job_inputs.get("private_key")

sql_schemas = get_snowflake_schemas(
account_id=account_id,
database=database,
warehouse=warehouse,
user=auth_type_username,
password=auth_type_password,
schema=sf_schema,
role=role,
auth_type=auth_type,
passphrase=auth_type_passphrase,
private_key=auth_type_private_key,
)

schemas_to_sync = list(sql_schemas.keys())
else:
Expand Down
13 changes: 11 additions & 2 deletions posthog/warehouse/api/external_data_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,23 @@ def incremental_fields(self, request: Request, *args: Any, **kwargs: Any):
sf_schema = source.job_inputs.get("schema")
role = source.job_inputs.get("role")

auth_type = source.job_inputs.get("auth_type", "password")
auth_type_username = source.job_inputs.get("user")
auth_type_password = source.job_inputs.get("password")
auth_type_passphrase = source.job_inputs.get("passphrase")
auth_type_private_key = source.job_inputs.get("private_key")

sf_schemas = get_snowflake_schemas(
account_id=account_id,
database=database,
warehouse=warehouse,
user=user,
password=password,
user=auth_type_username,
password=auth_type_password,
schema=sf_schema,
role=role,
auth_type=auth_type,
passphrase=auth_type_passphrase,
private_key=auth_type_private_key,
)

columns = sf_schemas.get(instance.name, [])
Expand Down
Loading
Loading