Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1436] [sqlalchemy] Include sqlalchemy __version__ in user-agent #332

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
56 changes: 30 additions & 26 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from typing import Any, List, Optional, Dict, Union
import re
from typing import Any, Dict, List, Optional, Union

import sqlalchemy
from sqlalchemy import DDL, event
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.interfaces import (
ReflectedColumn,
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

import databricks.sqlalchemy._ddl as dialect_ddl_impl
import databricks.sqlalchemy._types as dialect_type_impl
Expand All @@ -8,24 +21,12 @@
_match_table_not_found_string,
build_fk_dict,
build_pk_dict,
get_comment_from_dte_output,
get_fk_strings_from_dte_output,
get_pk_strings_from_dte_output,
get_comment_from_dte_output,
parse_column_info_from_tgetcolumnsresponse,
)

import sqlalchemy
from sqlalchemy import DDL, event
from sqlalchemy.engine import Connection, Engine, default, reflection
from sqlalchemy.engine.interfaces import (
ReflectedForeignKeyConstraint,
ReflectedPrimaryKeyConstraint,
ReflectedColumn,
ReflectedTableComment,
)
from sqlalchemy.engine.reflection import ReflectionDefaults
from sqlalchemy.exc import DatabaseError, SQLAlchemyError

try:
import alembic
except ImportError:
Expand Down Expand Up @@ -401,6 +402,21 @@ def get_table_comment(
return ReflectionDefaults.table_comment()


SQLALCHEMY_TAG = f"sqlalchemy/{sqlalchemy.__version__}"
sqlalchemy_version_tag_pat = r"sqlalchemy/(\d+\.\d+\.\d+)"


def add_sqla_tag_if_not_present(val: Optional[str] = None):
if val is None or val == "":
output = SQLALCHEMY_TAG
elif re.search(sqlalchemy_version_tag_pat, val):
output = val
else:
output = f"{SQLALCHEMY_TAG} + {val}"

return output


@event.listens_for(Engine, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
"""Helpful for DS on traffic from clients using SQLAlchemy in particular"""
Expand All @@ -411,18 +427,6 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams):

ua = cparams.get("_user_agent_entry", "")

def add_sqla_tag_if_not_present(val: str):
if not val:
output = "sqlalchemy"

if val and "sqlalchemy" in val:
output = val

else:
output = f"sqlalchemy + {val}"

return output

cparams["_user_agent_entry"] = add_sqla_tag_if_not_present(ua)

if sqlalchemy.__version__.startswith("1.3"):
Expand Down
101 changes: 75 additions & 26 deletions src/databricks/sqlalchemy/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import datetime
import decimal
import re
from typing import Tuple, Union, List
from unittest import skipIf

import pytest
import sqlalchemy
from sqlalchemy import (
Column,
MetaData,
Expand All @@ -20,6 +22,12 @@
from sqlalchemy.schema import DropColumnComment, SetColumnComment
from sqlalchemy.types import BOOLEAN, DECIMAL, Date, Integer, String

from databricks.sqlalchemy.base import (
SQLALCHEMY_TAG,
add_sqla_tag_if_not_present,
sqlalchemy_version_tag_pat,
)

try:
from sqlalchemy.orm import declarative_base
except ImportError:
Expand Down Expand Up @@ -120,20 +128,6 @@ def test_can_connect(db_engine):
assert len(result) == 1


def test_connect_args(db_engine):
"""Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI

This will most commonly happen when partners supply a user agent entry
"""

conn = db_engine.connect()
connection_headers = conn.connection.thrift_backend._transport._headers
user_agent = connection_headers["User-Agent"]

expected = f"(sqlalchemy + {USER_AGENT_TOKEN})"
assert expected in user_agent


@pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4")
@pytest.mark.skip(
reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass."
Expand Down Expand Up @@ -448,25 +442,80 @@ def test_has_table_across_schemas(
conn.execute(text("DROP TABLE test_has_table;"))


def test_user_agent_adjustment(db_engine):
# If .connect() is called multiple times on an engine, don't keep pre-pending the user agent
# https://github.com/databricks/databricks-sql-python/issues/192
c1 = db_engine.connect()
c2 = db_engine.connect()
class TestUserAgent:
@pytest.fixture(scope="class")
def expected_sqlalchemy_tag(self):
import sqlalchemy

user_agent_tag = f"sqlalchemy/{sqlalchemy.__version__}"
return user_agent_tag

def get_conn_user_agent(conn):
def get_conn_user_agent(self, conn):
return conn.connection.dbapi_connection.thrift_backend._transport._headers.get(
"User-Agent"
)

ua1 = get_conn_user_agent(c1)
ua2 = get_conn_user_agent(c2)
same_ua = ua1 == ua2
def test_user_agent_adjustment(self, db_engine):
# If .connect() is called multiple times on an engine, don't keep pre-pending the user agent
# https://github.com/databricks/databricks-sql-python/issues/192
c1 = db_engine.connect()
c2 = db_engine.connect()

ua1 = self.get_conn_user_agent(c1)
ua2 = self.get_conn_user_agent(c2)
same_ua = ua1 == ua2

c1.close()
c2.close()

c1.close()
c2.close()
assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}"

assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}"
def test_sqlalchemy_user_agent_includes_version(self, db_engine):
"""So that we know when we can safely deprecate support for sqlalchemy 1.x"""

version_str = sqlalchemy.__version__
c = db_engine.connect()
ua = self.get_conn_user_agent(c)

assert version_str in ua

def test_user_supplied_string(self, db_engine):
"""Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI

This will most commonly happen when partners supply a user agent entry
"""

conn = db_engine.connect()
connection_headers = conn.connection.thrift_backend._transport._headers
user_agent = connection_headers["User-Agent"]

assert USER_AGENT_TOKEN in user_agent

@pytest.mark.parametrize(
"input, expected",
(
(None, "{}"),
("", "{}"),
("sqlalchemy connection", "{} + sqlalchemy connection"),
(
"reusable dialect compliance tests",
"{} + reusable dialect compliance tests",
),
),
)
def test_user_agent_insertion_behavior(
self, input: Union[str, None], expected: str, expected_sqlalchemy_tag: str
):
assert add_sqla_tag_if_not_present(input) == expected.format(
expected_sqlalchemy_tag
)

@pytest.mark.parametrize(
"input",
("sqlalchemy/1.4.0", "sqlalchemy/1.3.0", "sqlalchemy/2.0.0", SQLALCHEMY_TAG),
)
def test_sqlalchemy_tag_regexes_properly(self, input):
assert re.search(sqlalchemy_version_tag_pat, input)


@pytest.fixture
Expand Down
Loading