diff --git a/preql/sql_interface.py b/preql/sql_interface.py index b51df71..89a66c3 100644 --- a/preql/sql_interface.py +++ b/preql/sql_interface.py @@ -439,7 +439,7 @@ class SnowflakeInterface(SqlInterface): id_type_decl = "number autoincrement" max_rows_per_query = 16384 - def __init__(self, account, user, password, path, schema, database, print_sql=False): + def __init__(self, account, user, password, warehouse, schema, database, role=None, print_sql=False): import logging logging.getLogger('snowflake.connector').setLevel(logging.WARNING) import snowflake.connector @@ -447,11 +447,12 @@ def __init__(self, account, user, password, path, schema, database, print_sql=Fa self._client = snowflake.connector.connect( user=user, password=password, - account=account - ) - self._client.cursor().execute(f"USE WAREHOUSE {path.lstrip('/')}") - self._client.cursor().execute(f"USE DATABASE {database}") - self._client.cursor().execute(f"USE SCHEMA {schema}") + account=account, + role=role, + database=database, + warehouse=warehouse, + schema=schema, + ) self._print_sql = print_sql @@ -909,6 +910,16 @@ def create_engine(db_uri, print_sql, auto_create): raise NotImplementedError("Preql doesn't support multiple schemes") scheme ,= dsn.schemes + if scheme == "snowflake": + database, schema = dsn.paths + try: + warehouse = dsn.query["warehouse"] + except KeyError: + raise ValueError( + "Must provide warehouse. Format: 'snowflake://:@//?warehouse='" + ) + return SnowflakeInterface(dsn.host, dsn.user, dsn.password, warehouse=warehouse, database=database, schema=schema, print_sql=print_sql) + if len(dsn.paths) == 0: path = '' elif len(dsn.paths) == 1: @@ -926,8 +937,6 @@ def create_engine(db_uri, print_sql, auto_create): return DuckInterface(path, print_sql=print_sql) elif scheme == 'bigquery': return BigQueryInterface(path, print_sql=print_sql) - elif scheme == 'snowflake': - return SnowflakeInterface(dsn.host, dsn.user, dsn.password, path, **dsn.query, print_sql=print_sql) elif scheme == 'redshift': return RedshiftInterface(dsn.host, dsn.port, path, dsn.user, dsn.password, print_sql=print_sql) elif scheme == 'oracle':