From 5a4fdfcb580d8eda8ebedbeb1e1f5e86a7d77240 Mon Sep 17 00:00:00 2001 From: Chanukya Konuganti Date: Thu, 15 Feb 2024 11:58:58 -0800 Subject: [PATCH] added fix for snowflake operator to accept authenticator and make parameters optional based on load_type (#92) * Making parameters like sf_cluster_keys, sf_grantee_roles optional, improving OO code , fixing authenticator issue * reverting parameter name change to avoid compatibility issues --- brickflow/engine/task.py | 1 + .../databricks/uc_to_snowflake_operator.py | 176 +++++++++++------- docs/faq/faq.md | 7 +- docs/tasks.md | 26 ++- .../brickflow_examples/workflows/demo_wf.py | 7 +- .../dev_bundle_polyrepo_with_auto_libs.yml | 3 + tests/engine/test_task.py | 6 +- 7 files changed, 145 insertions(+), 81 deletions(-) diff --git a/brickflow/engine/task.py b/brickflow/engine/task.py index e0a1f110..929360a3 100644 --- a/brickflow/engine/task.py +++ b/brickflow/engine/task.py @@ -821,6 +821,7 @@ def get_brickflow_libraries(enable_plugins: bool = False) -> List[TaskLibrary]: return [ bf_lib, PypiTaskLibrary("apache-airflow==2.6.3"), + PypiTaskLibrary("snowflake==0.5.1"), MavenTaskLibrary("com.cronutils:cron-utils:9.2.0"), ] else: diff --git a/brickflow_plugins/databricks/uc_to_snowflake_operator.py b/brickflow_plugins/databricks/uc_to_snowflake_operator.py index 18c7b813..f67ff545 100644 --- a/brickflow_plugins/databricks/uc_to_snowflake_operator.py +++ b/brickflow_plugins/databricks/uc_to_snowflake_operator.py @@ -1,4 +1,4 @@ -import logging +import logging as log try: import snowflake.connector @@ -21,6 +21,13 @@ def run_snowflake_queries(*args): """ ) +try: + from brickflow import ctx +except ImportError: + raise ImportError( + "plugin requires brickflow context , please install library at cluster/workflow/task level" + ) + class SnowflakeOperatorException(Exception): pass @@ -37,7 +44,7 @@ class SnowflakeOperator: Example Usage in your brickflow task SnowflakeOperator( secret_scope=databricks_secrets_psc - parameters= sf_load_parameters + query_string=queries separated by semicolon ) As databricks secrets is a key value store, code expects the secret scope to contain the below exact keys @@ -53,27 +60,31 @@ class SnowflakeOperator: above code snippet expects the data as follows databricks_secrets_psc contains username, password, account, warehouse, database and role keys with snowflake values - sf_load_parameters = {'query': comma_separeted_list_of_queries} - + query_string : required parameter with queries separeted by semicolon(;) + parameters: optional parameter dictionary with key value pairs to substitute in the query """ - def __init__(self, secret_scope, parameters={}, *args, **kwargs): + def __init__(self, secret_scope, query_string, parameters={}, *args, **kwargs): self.cur = None self.query = None - self.parameters = parameters self.secret_scope = secret_scope - self.log = logging - self.query = parameters.get("query") or None - self.authenticator = None + self.log = log + self.query = query_string + self.parameters = parameters if not self.secret_scope: raise ValueError( "Must provide reference to Snowflake connection in databricks secretes !" 
) + try: - import base64 - from brickflow import ctx + self.authenticator = ctx.dbutils.secrets.get( + self.secret_scope, "authenticator" + ) + except: + self.authenticator = None + try: self.username = ctx.dbutils.secrets.get(self.secret_scope, "username") self.password = ctx.dbutils.secrets.get(self.secret_scope, "password") self.account = ctx.dbutils.secrets.get(self.secret_scope, "account") @@ -90,8 +101,8 @@ def get_snowflake_connection(self): """ logic to connect to snowflake instance with provided details and return a connection object """ - if self.authenticator: - print( + if self.authenticator is not None: + self.log.info( "snowflake_account_name={0}, database={1}, username={2}, warehouse={3}, role={4}, authenticator={5}".format( self.account, self.database, @@ -111,7 +122,7 @@ def get_snowflake_connection(self): authenticator=self.authenticator, ) else: - print( + self.log.info( "snowflake_account_name={0}, database={1}, username={2}, warehouse={3}, role={4}".format( self.account, self.database, @@ -128,6 +139,7 @@ def get_snowflake_connection(self): database=self.database, role=self.role, ) + self.parameters.update( { "account_name": self.account, @@ -137,6 +149,7 @@ def get_snowflake_connection(self): "role": self.role, } ) + return con def get_cursor(self): @@ -144,12 +157,13 @@ def get_cursor(self): logic to create a cursor for a successful snowflake connection to execute queries """ try: - # self.log.info('getting connection for secret scope id {}'.format(self.secret_scope)) + self.log.info( + "establishing connection for secret scope id {}".format( + self.secret_scope + ) + ) con = self.get_snowflake_connection() except snowflake.connector.errors.ProgrammingError as e: - # default error message - # self.log.warning(e) - # customer error message raise ValueError( "Error {0} ({1}): {2} ({3})".format(e.errno, e.sqlstate, e.msg, e.sfqid) ) @@ -198,7 +212,6 @@ def snowflake_query_exec(self, cur, database, query_string): e.errno, e.sqlstate, e.msg, e.sfqid ) ) - self.log.info("Query completed successfully") self.log.info("All Query/Queries completed successfully") @@ -212,6 +225,8 @@ def execute(self): # Run the query against SnowFlake try: self.snowflake_query_exec(self.cur, self.database, query_string) + except: + self.log.error("failed to execute") finally: self.cur.close() self.log.info("Closed connection") @@ -245,26 +260,33 @@ class UcToSnowflakeOperator(SnowflakeOperator): warehouse: warehouse/cluster information that user has access for ex: sample_warehouse database : default database that we want to connect for ex: sample_database role : role to which the user has write access for ex: sample_write_role + Authenticator: optional additional authenticator needed for connection for ex: okta_connection_url above code snippet expects the data as follows databricks_secrets_psc contains username, password, account, warehouse, database and role keys with snowflake values - uc_parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema', + parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema', 'dbx_table':'sf_operator_1', 'sf_schema':'stage','sf_table':'SF_OPERATOR_1', 'sf_grantee_roles':'downstream_read_role', 'incremental_filter':"dt='2023-10-22'", 'sf_cluster_keys':''} + + in the parameters dictionary we have mandatory keys as follows + load_type(required): incremental/full + dbx_catalog (required): name of the catalog in unity + dbx_database (required): schema name within the catalog + 
dbx_table (required): name of the object in the schema + sf_database (optional): database name in snowflake + sf_schema (required): snowflake schema in the database provided as part of scope + sf_table (required): name of the table in snowflake to which we want to append or overwrite + incremental_filter (optional): mandatory parameter for incremental load type to delete existing data in snowflake table + dbx_data_filter (optional): parameter to filter databricks table if different from snowflake filter + sf_cluster_keys (optional): list of keys to cluster the data in snowflake """ def __init__(self, secret_scope, parameters={}, *args, **kwargs): - self.secret_scope = secret_scope - self.sf_pre_steps_sql = None - self.sf_post_steps_sql = None - self.sf_post_grants_sql = None - self.cur = None - self.conn = None - self.log = logging - self.parameters = parameters + SnowflakeOperator.__init__(self, secret_scope, "", parameters) + self.dbx_data_filter = self.parameters.get("dbx_data_filter") or None self.write_mode = None - self.sf_cluster_keys = None + """ self.authenticator = None try: import base64 @@ -281,31 +303,41 @@ def __init__(self, secret_scope, parameters={}, *args, **kwargs): "Failed to fetch details from secret scope for username, password, account, warehouse, \ database, role !" ) + """ def get_sf_presteps(self): queries = """ - CREATE OR REPLACE TABLE {sfSchema}.{sfTable_clone} CLONE {sfSchema}.{sfTable}; - DELETE FROM {sfSchema}.{sfTable_clone} WHERE {incremental_filter}""".format( + CREATE OR REPLACE TABLE {sfDatabase}.{sfSchema}.{sfTable_clone} CLONE {sfDatabase}.{sfSchema}.{sfTable}; + DELETE FROM {sfDatabase}.{sfSchema}.{sfTable_clone} WHERE {data_filter}""".format( sfSchema=self.parameters["sf_schema"], sfTable_clone=self.parameters["sf_table"] + "_clone", sfTable=self.parameters["sf_table"], - incremental_filter=self.parameters["incremental_filter"], + sfDatabase=self.sf_database, + data_filter=self.parameters["incremental_filter"], ) return queries def get_sf_poststeps(self): - queries = """ ALTER TABLE {sfSchema}.{sfTable_clone} SWAP WITH {sfSchema}.{sfTable}; DROP TABLE {sfSchema}.{sfTable_clone} """.format( + queries = """ ALTER TABLE {sfDatabase}.{sfSchema}.{sfTable_clone} SWAP WITH {sfDatabase}.{sfSchema}.{sfTable}; + DROP TABLE {sfDatabase}.{sfSchema}.{sfTable_clone} """.format( sfSchema=self.parameters["sf_schema"], sfTable_clone=self.parameters["sf_table"] + "_clone", sfTable=self.parameters["sf_table"], + sfDatabase=self.sf_database, ) return queries def get_sf_postgrants(self): - queries = """ GRANT SELECT ON TABLE {sfSchema}.{sfTable} TO ROLE {sfGrantee_roles};""".format( + post_grantee_role = ( + self.parameters["sf_grantee_roles"] + if "sf_grantee_roles" in self.parameters.keys() + else self.role + ) + queries = """ GRANT SELECT ON TABLE {sfDatabase}.{sfSchema}.{sfTable} TO ROLE {sfGrantee_roles};""".format( sfSchema=self.parameters["sf_schema"], sfTable=self.parameters["sf_table"], - sfGrantee_roles=self.parameters["sf_grantee_roles"], + sfGrantee_roles=post_grantee_role, + sfDatabase=self.sf_database, ) return queries @@ -322,7 +354,6 @@ def validate_input_params(self): "dbx_table", "sf_schema", "sf_table", - "sf_grantee_roles", ) if not all(key in self.parameters for key in mandatory_keys): self.log.info( @@ -333,14 +364,25 @@ def validate_input_params(self): "Mandatory key(s) NOT exists in UcToSnowflakeOperator(parameters): %s\n" % format(self.parameters) ) - raise Exception("Job failed") + raise Exception("Job failed due to missing manadatory key") # 
Setting up pre,post and grants scripts for snowflake
-            self.sf_pre_steps_sql = self.get_sf_presteps()
-            self.sf_post_steps_sql = self.get_sf_poststeps()
             self.sf_post_grants_sql = self.get_sf_postgrants()
+
+            if self.parameters["load_type"] == "incremental":
+                if "incremental_filter" not in self.parameters.keys():
+                    self.log.info(
+                        "mandatory key incremental_filter is missing for incremental loads"
+                    )
+                    self.log.error(
+                        "mandatory key incremental_filter is missing for incremental loads"
+                    )
+                    raise Exception("Job failed due to missing mandatory key")
+                self.sf_pre_steps_sql = self.get_sf_presteps()
+                self.sf_post_steps_sql = self.get_sf_poststeps()
+
         else:
             self.log.error("Input is NOT a dictionary: %s\n" % format(self.parameters))
-            raise Exception("Job failed")
+            raise Exception("Job failed due to missing mandatory key")
 
     def submit_job_snowflake(self, query_input):
         """
@@ -352,7 +394,7 @@ def submit_job_snowflake(self, query_input):
             self.get_cursor()
             query_string_list = str(query_input).strip().split(";")
             for query_string in query_string_list:
-                print(query_string)
+                self.log.info(query_string)
                 self.snowflake_query_exec(self.cur, self.database, query_string.strip())
 
         except Exception as e:
@@ -366,37 +408,29 @@ def apply_grants(self):
         """
        Function to apply grants after successful execution
         """
-        grantee_roles = self.parameters.get("sf_grantee_roles")
+        grantee_roles = self.parameters.get("sf_grantee_roles") or self.role
         for grantee_role in grantee_roles.split(","):
             self.parameters.update({"sf_grantee_roles": grantee_role})
             self.submit_job_snowflake(self.sf_post_grants_sql)
 
     def extract_source(self):
-        from brickflow import ctx
-
         if self.parameters["load_type"] == "incremental":
-            dbx_incremental_filter = (
-                self.parameters["dbx_incremental_filter"]
-                if "dbx_incremental_filter" in self.parameters.keys()
-                else self.parameters["incremental_filter"]
+            self.dbx_data_filter = (
+                self.parameters.get("dbx_data_filter")
+                or self.parameters.get("incremental_filter")
+                or "1=1"
             )
-            if dbx_incremental_filter:
-                df = ctx.spark.sql(
-                    """select * from {}.{}.{} where {}""".format(
-                        self.parameters["dbx_catalog"],
-                        self.parameters["dbx_database"],
-                        self.parameters["dbx_table"],
-                        dbx_incremental_filter,
-                    )
-                )
         else:
-            df = ctx.spark.sql(
-                """select * from {}.{}.{}""".format(
-                    self.parameters["dbx_catalog"],
-                    self.parameters["dbx_database"],
-                    self.parameters["dbx_table"],
-                )
+            self.dbx_data_filter = self.parameters.get("dbx_data_filter") or "1=1"
+
+        df = ctx.spark.sql(
+            """select * from {}.{}.{} where {}""".format(
+                self.parameters["dbx_catalog"],
+                self.parameters["dbx_database"],
+                self.parameters["dbx_table"],
+                self.dbx_data_filter,
             )
+        )
         return df
 
     def load_snowflake(self, source_df, target_table):
@@ -407,10 +441,12 @@ def load_snowflake(self, source_df, target_table):
             "sfUser": self.username,
             "sfPassword": self.password,
             "sfWarehouse": self.warehouse,
-            "sfDatabase": self.database,
+            "sfDatabase": self.sf_database,
             "sfSchema": self.parameters["sf_schema"],
             "sfRole": self.role,
         }
+        if self.authenticator is not None:
+            sf_options["sfAuthenticator"] = self.authenticator
         self.log.info("snowflake package and options defined...!!!")
         if len(source_df.take(1)) == 0:
             self.write_mode = "Append"
@@ -419,7 +455,7 @@ def load_snowflake(self, source_df, target_table):
             source_df.write.format(sf_package).options(**sf_options).option(
                 "dbtable",
                 "{0}.{1}.{2}".format(
-                    self.database, self.parameters["sf_schema"], target_table
+                    self.sf_database, self.parameters["sf_schema"], target_table
                 ),
).mode("{0}".format(self.write_mode)).save() @@ -430,7 +466,7 @@ def load_snowflake(self, source_df, target_table): ).option( "dbtable", "{0}.{1}.{2}".format( - self.database, self.parameters["sf_schema"], target_table + self.sf_database, self.parameters["sf_schema"], target_table ), ).mode( "{0}".format(self.write_mode) @@ -449,17 +485,19 @@ def submit_job_compute(self): ) self.sf_cluster_keys = ( [] - if self.parameters["sf_cluster_keys"] is None + if "sf_cluster_keys" not in self.parameters.keys() else self.parameters["sf_cluster_keys"] ) - self.log.info("loading data to snowflake") + self.log.info("loading data to snowflake table") self.load_snowflake(source_data, target_table) + self.log.info("successfully loaded data to snowflake table ") def execute(self): """ Main method for execution """ # Validate the input parameters + self.sf_database = self.parameters.get("sf_database") or self.database self.validate_input_params() # Identify the execution environment @@ -485,6 +523,6 @@ def execute(self): else: raise Exception( - "NOT a supported value for load_type: %s\n" + "NOT a supported value for load_type: %s \n please provide either full or incremental" % format(self.parameters.get("load_type")) ) diff --git a/docs/faq/faq.md b/docs/faq/faq.md index 00c20d48..34a87692 100644 --- a/docs/faq/faq.md +++ b/docs/faq/faq.md @@ -68,8 +68,9 @@ wf = Workflow(...) @wf.task def run_snowflake_queries(*args): sf_query_run = SnowflakeOperator( - secret_cope = "your_databricks secrets scope name", - input_params = {'query':"comma_seprated_list_of_queries"} + secret_scope = "your_databricks secrets scope name", + query_string = "string of queries separated by semicolon(;)", + parameters={"key1":"value1", "key2":"value2"} ) sf_query_run.execute() ``` @@ -85,7 +86,7 @@ wf = Workflow(...) def copy_from_uc_sf(*args): uc_to_sf_copy = UcToSnowflakeOperator( secret_scope = "your_databricks secrets scope name", - uc_parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema', + parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema', 'dbx_table':'sf_operator_1', 'sf_schema':'stage','sf_table':'SF_OPERATOR_1', 'sf_grantee_roles':'downstream_read_role', 'incremental_filter':"dt='2023-10-22'", 'sf_cluster_keys':''} diff --git a/docs/tasks.md b/docs/tasks.md index 02b5f931..22cbbeae 100644 --- a/docs/tasks.md +++ b/docs/tasks.md @@ -409,6 +409,11 @@ As databricks secrets is a key value store, code expects the secret scope to con     database : default database that we want to connect for ex: sample_database     role : role to which the user has write access for ex: sample_write_role +SnowflakeOperator can accept the following as inputs +    secret_scope (required): databricks secret scope identifier +    query_string (required): queries separated by semicolon +    parameters (optional) : dictionary with variables that can be used to substitute in queries + ```python title="snowflake_operator" from brickflow_plugins import SnowflakeOperator @@ -417,8 +422,9 @@ wf = Workflow(...) 
 @wf.task
 def run_snowflake_queries(*args):
     sf_query_run = SnowflakeOperator(
-        secret_cope = "your_databricks secrets scope name",
-        input_params = {'query':"comma_seprated_list_of_queries"}
+        secret_scope = "your_databricks secrets scope name",
+        query_string = "select * from database.$schema.$table where $filter_condition; select * from sample_schema.test_table",
+        parameters = {"schema":"test_schema","table":"sample_table","filter_condition":"col='something'"}
     )
     sf_query_run.execute()
 ```
 
@@ -436,6 +442,18 @@ As databricks secrets is a key value store, code expects the secret scope to con
     database : default database that we want to connect for ex: sample_database
     role : role to which the user has write access for ex: sample_write_role
 
+UcToSnowflakeOperator expects the following inputs in parameters to copy data
+    load_type (required): type of data load, acceptable values full or incremental
+    dbx_catalog (required) : name of the databricks catalog in which object resides
+    dbx_database (required): name of the databricks schema in which object is available
+    dbx_table (required) : name of the databricks object we want to copy to snowflake
+    sf_database (optional) : name of the snowflake database if different from the one in secret_scope
+    sf_schema (required): name of the snowflake schema in which we want to copy the data
+    sf_table (required) : name of the snowflake object to which we want to copy from databricks
+    incremental_filter (required for incremental mode) : condition to manage data before writing to snowflake
+    dbx_data_filter (optional): filter condition on databricks source for full or incremental (if different from incremental_filter)
+    sf_grantee_roles (optional) : snowflake roles to which we want to grant select/read access
+    sf_cluster_keys (optional) : list of keys on which to cluster the snowflake table
 
 ```python title="uc_to_snowflake_operator"
 from brickflow_plugins import UcToSnowflakeOperator
 
@@ -445,8 +463,8 @@ wf = Workflow(...)
 @wf.task
 def run_snowflake_queries(*args):
     uc_to_sf_copy = UcToSnowflakeOperator(
-        secret_cope = "your_databricks secrets scope name",
-        uc_parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema',
+        secret_scope = "your_databricks secrets scope name",
+        parameters = {'load_type':'incremental','dbx_catalog':'sample_catalog','dbx_database':'sample_schema',
                       'dbx_table':'sf_operator_1', 'sf_schema':'stage','sf_table':'SF_OPERATOR_1',
                       'sf_grantee_roles':'downstream_read_role', 'incremental_filter':"dt='2023-10-22'",
                       'sf_cluster_keys':''}
diff --git a/examples/brickflow_examples/workflows/demo_wf.py b/examples/brickflow_examples/workflows/demo_wf.py
index 44f34c46..e600bf76 100644
--- a/examples/brickflow_examples/workflows/demo_wf.py
+++ b/examples/brickflow_examples/workflows/demo_wf.py
@@ -273,7 +273,7 @@ def airflow_autosys_sensor():
 def run_snowflake_queries(*args):
     uc_to_sf_copy = UcToSnowflakeOperator(
         secret_cope="sample_scope",
-        uc_parameters={
+        parameters={
             "load_type": "incremental",
             "dbx_catalog": "sample_catalog",
             "dbx_database": "sample_schema",
@@ -282,6 +282,7 @@ def run_snowflake_queries(*args):
             "sf_table": "SF_OPERATOR_1",
             "sf_grantee_roles": "downstream_read_role",
             "incremental_filter": "dt='2023-10-22'",
+            "dbx_data_filter": "run_dt='2023-10-21'",
             "sf_cluster_keys": "",
         },
     )
@@ -291,7 +292,9 @@ def run_snowflake_queries(*args):
     sf_query_run = SnowflakeOperator(
-        secret_cope="sample_scope", input_params={"query": "select * from table"}
+        secret_scope="sample_scope",
+        query_string="select * from table; insert into table1 select * from $database.table2",
+        parameters={"database": "sample_db"},
     )
     sf_query_run.execute()
diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml
index 6c6b6f96..fd4b8931 100644
--- a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml
+++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml
@@ -34,6 +34,9 @@ targets:
           - pypi:
               package: apache-airflow==2.6.3
               repo: null
+          - pypi:
+              package: snowflake==0.5.1
+              repo: null
           - maven:
               coordinates: com.cronutils:cron-utils:9.2.0
               exclusions: null
diff --git a/tests/engine/test_task.py b/tests/engine/test_task.py
index 593fad8c..9c3e54a6 100644
--- a/tests/engine/test_task.py
+++ b/tests/engine/test_task.py
@@ -423,7 +423,7 @@ def test_get_brickflow_lib_version(self):
     def test_get_brickflow_libraries(self):
         settings = BrickflowProjectDeploymentSettings()
         settings.brickflow_project_runtime_version = "1.0.0"
-        assert len(get_brickflow_libraries(enable_plugins=True)) == 3
+        assert len(get_brickflow_libraries(enable_plugins=True)) == 4
         assert len(get_brickflow_libraries(enable_plugins=False)) == 1
         lib = get_brickflow_libraries(enable_plugins=False)[0].dict
         expected = {
@@ -439,7 +439,7 @@ def test_get_brickflow_libraries_semver_non_numeric(self):
         settings = BrickflowProjectDeploymentSettings()
         tag = "1.0.1rc1234"
         settings.brickflow_project_runtime_version = tag
-        assert len(get_brickflow_libraries(enable_plugins=True)) == 3
+        assert len(get_brickflow_libraries(enable_plugins=True)) == 4
         assert len(get_brickflow_libraries(enable_plugins=False)) == 1
         lib = get_brickflow_libraries(enable_plugins=False)[0].dict
         expected = {
@@ -455,7 +455,7 @@ def test_get_brickflow_libraries_non_semver(self):
         settings = BrickflowProjectDeploymentSettings()
         tag = "somebranch"
settings.brickflow_project_runtime_version = tag - assert len(get_brickflow_libraries(enable_plugins=True)) == 3 + assert len(get_brickflow_libraries(enable_plugins=True)) == 4 assert len(get_brickflow_libraries(enable_plugins=False)) == 1 lib = get_brickflow_libraries(enable_plugins=False)[0].dict expected = {