From 65e29f511c81c8030edb0f8b4f072cb074a5bd74 Mon Sep 17 00:00:00 2001 From: Neha Rane Date: Fri, 26 Apr 2024 20:55:03 +0530 Subject: [PATCH] accepting query_tag in connection arguments --- src/snowflake/sqlalchemy/_constants.py | 3 +++ src/snowflake/sqlalchemy/snowdialect.py | 20 ++++++-------------- src/snowflake/sqlalchemy/util.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..9fc80afb 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,6 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION + +PARAM_QUERY_TAG = "query_tag" +PARAM_SESSION_PARAMETERS = "session_parameters" \ No newline at end of file diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 2e40d03c..bcd34b25 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -63,7 +63,7 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .util import _update_connection_application_name, parse_url_boolean, _update_connection_session_parameters colspecs = { Date: _CUSTOM_Date, @@ -110,7 +110,7 @@ "GEOMETRY": GEOMETRY, } -_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True +_CUSTOMIZE_APPLICATION_NAME_AND_SESSION_PARAMETERS = True class SnowflakeDialect(default.DefaultDialect): @@ -888,18 +888,10 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): } def connect(self, *cargs, **cparams): - return ( - super().connect( - *cargs, - **( - _update_connection_application_name(**cparams) - if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME - else cparams - ), - ) - if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME - else super().connect(*cargs, **cparams) - ) + if _CUSTOMIZE_APPLICATION_NAME_AND_SESSION_PARAMETERS: + cparams = _update_connection_application_name(**cparams) + cparams = _update_connection_session_parameters(**cparams) + return super().connect(*cargs, **cparams) @sa_vnt.listens_for(Table, "before_create") diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 32e07373..0282a3f0 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -26,6 +26,8 @@ PARAM_INTERNAL_APPLICATION_NAME, PARAM_INTERNAL_APPLICATION_VERSION, SNOWFLAKE_SQLALCHEMY_VERSION, + PARAM_QUERY_TAG, + PARAM_SESSION_PARAMETERS, ) @@ -115,6 +117,14 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any: return conn_kwargs +def _update_connection_session_parameters(**conn_kwargs: Any) -> Any: + if PARAM_QUERY_TAG in conn_kwargs: + session_parameters = {} + session_parameters.update({"QUERY_TAG": conn_kwargs[PARAM_QUERY_TAG]}) + conn_kwargs[PARAM_SESSION_PARAMETERS] = session_parameters + return conn_kwargs + + def parse_url_boolean(value: str) -> bool: if value == "True": return True