Skip to content

Commit

Permalink
Change snowflake URI format to adhere to the SQLAlchemy convention
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Jun 10, 2022
1 parent ef434f7 commit 2873d18
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions preql/sql_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,19 +439,20 @@ class SnowflakeInterface(SqlInterface):
id_type_decl = "number autoincrement"
max_rows_per_query = 16384

def __init__(self, account, user, password, path, schema, database, print_sql=False):
def __init__(self, account, user, password, warehouse, schema, database, role=None, print_sql=False):
import logging
logging.getLogger('snowflake.connector').setLevel(logging.WARNING)
import snowflake.connector

self._client = snowflake.connector.connect(
user=user,
password=password,
account=account
)
self._client.cursor().execute(f"USE WAREHOUSE {path.lstrip('/')}")
self._client.cursor().execute(f"USE DATABASE {database}")
self._client.cursor().execute(f"USE SCHEMA {schema}")
account=account,
role=role,
database=database,
warehouse=warehouse,
schema=schema,
)

self._print_sql = print_sql

Expand Down Expand Up @@ -909,6 +910,16 @@ def create_engine(db_uri, print_sql, auto_create):
raise NotImplementedError("Preql doesn't support multiple schemes")
scheme ,= dsn.schemes

if scheme == "snowflake":
database, schema = dsn.paths
try:
warehouse = dsn.query["warehouse"]
except KeyError:
raise ValueError(
"Must provide warehouse. Format: 'snowflake://<user>:<pass>@<account>/<database>/<schema>?warehouse=<warehouse>'"
)
return SnowflakeInterface(dsn.host, dsn.user, dsn.password, warehouse=warehouse, database=database, schema=schema, print_sql=print_sql)

if len(dsn.paths) == 0:
path = ''
elif len(dsn.paths) == 1:
Expand All @@ -926,8 +937,6 @@ def create_engine(db_uri, print_sql, auto_create):
return DuckInterface(path, print_sql=print_sql)
elif scheme == 'bigquery':
return BigQueryInterface(path, print_sql=print_sql)
elif scheme == 'snowflake':
return SnowflakeInterface(dsn.host, dsn.user, dsn.password, path, **dsn.query, print_sql=print_sql)
elif scheme == 'redshift':
return RedshiftInterface(dsn.host, dsn.port, path, dsn.user, dsn.password, print_sql=print_sql)
elif scheme == 'oracle':
Expand Down

0 comments on commit 2873d18

Please sign in to comment.