diff --git a/README.md b/README.md index 52294e9..2495de3 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ To connect to Solr with SQLAlchemy, the following URL pattern can be used: solr://:@:/solr/[?parameter=value] ``` +_Note_: port 8983 is used when `port` in the URL is omitted + ### Authentication #### Basic Authentication @@ -162,6 +164,7 @@ translates to `[2024-01-01T00:00:00Z TO *]` | Aliases | ✗ | ✗ | ✓ | ✓ | ✓ | ✓ | | Built-in date range compilation | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | | `SELECT` _expression_ statements | ✗ | ✓ | ✓ | ✓ | ✓ | ✓ | +| SQL compilation caching | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | ## Use Cases diff --git a/src/sqlalchemy_solr/admin/solr_spec.py b/src/sqlalchemy_solr/admin/solr_spec.py index 7c6a61f..768db90 100644 --- a/src/sqlalchemy_solr/admin/solr_spec.py +++ b/src/sqlalchemy_solr/admin/solr_spec.py @@ -1,15 +1,51 @@ from requests import Session +from sqlalchemy_solr import defaults class SolrSpec: _spec = None - def __init__(self, solr_base_url): + def __init__(self, url): + """ + Initializes a SolrSpec object + + :param url: Solr base url which can be a string HTTP(S) URL or a sqlalchemy.engine.url.URL. + """ + session = Session() + + if isinstance(url, str): + base_url = url + else: + if "verify_ssl" in url.query and url.query["verify_ssl"] in [ + "False", + "false", + ]: + session.verify = False + + token = None + if "token" in url.query: + token = url.query["token"] + + if token is not None: + session.headers.update({"Authorization": f"Bearer {token}"}) + else: + session.auth = (url.username, url.password) + + proto = "http" + if "use_ssl" in url.query and url.query["use_ssl"] in ["True", "true"]: + proto = "https" + + server_path = url.database.split("/")[0] + + port = url.port or defaults.PORT + base_url = f"{proto}://{url.host}:{port}/{server_path}" + sys_info_response = session.get( - solr_base_url + "/admin/info/system", params={"wt": "json"} + base_url + "/admin/info/system", params={"wt": "json"} ) + spec_version = sys_info_response.json()["lucene"]["solr-spec-version"] self._spec = list(map(int, spec_version.split("."))) diff --git a/src/sqlalchemy_solr/base.py b/src/sqlalchemy_solr/base.py index 607dd12..c6cc3ad 100644 --- a/src/sqlalchemy_solr/base.py +++ b/src/sqlalchemy_solr/base.py @@ -30,10 +30,10 @@ from sqlalchemy.sql import expression from sqlalchemy.sql import operators from sqlalchemy.sql.expression import BindParameter +from sqlalchemy_solr import release_flags from . import solrdbapi as module from .solr_type_compiler import SolrTypeCompiler -from .solrdbapi import Connection from .type_map import metadata_type_map logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR) @@ -42,8 +42,6 @@ class SolrCompiler(compiler.SQLCompiler): # pylint: disable=abstract-method - SOLR_DATE_RANGE_TRANS_RELEASE = 9 - merge_ops = (operators.ge, operators.gt, operators.le, operators.lt) bounds = { operators.ge: "[", @@ -70,7 +68,10 @@ def visit_binary( ): # Handled in Solr 9 - if Connection.solr_spec.spec()[0] >= self.SOLR_DATE_RANGE_TRANS_RELEASE: + if ( + SolrDialect.solr_spec.spec()[0] + >= release_flags.SOLR_DATE_RANGE_TRANS_RELEASE + ): return super().visit_binary(binary, override_operator, eager_grouping, **kw) if binary.operator not in self.merge_ops: @@ -157,7 +158,10 @@ def visit_binary( def visit_clauselist(self, clauselist, **kw): # Handled in Solr 9 - if Connection.solr_spec.spec()[0] >= self.SOLR_DATE_RANGE_TRANS_RELEASE: + if ( + SolrDialect.solr_spec.spec()[0] + >= release_flags.SOLR_DATE_RANGE_TRANS_RELEASE + ): return super().visit_clauselist(clauselist, **kw) if clauselist.operator == operators.and_: @@ -537,6 +541,8 @@ class SolrDialect(default.DefaultDialect): supports_native_boolean = True supports_statement_cache = True + solr_spec = None + def __init__(self, **kw): default.DefaultDialect.__init__(self, **kw) self.supported_extensions = [] diff --git a/src/sqlalchemy_solr/defaults.py b/src/sqlalchemy_solr/defaults.py new file mode 100644 index 0000000..a5af90d --- /dev/null +++ b/src/sqlalchemy_solr/defaults.py @@ -0,0 +1 @@ +PORT = 8983 diff --git a/src/sqlalchemy_solr/http.py b/src/sqlalchemy_solr/http.py index 35a76ed..c5b6d14 100755 --- a/src/sqlalchemy_solr/http.py +++ b/src/sqlalchemy_solr/http.py @@ -23,6 +23,9 @@ from requests import RequestException from requests import Session +from sqlalchemy_solr import defaults +from sqlalchemy_solr import release_flags +from sqlalchemy_solr.admin.solr_spec import SolrSpec from sqlalchemy_solr.solrdbapi.api_exceptions import DatabaseError from .api_globals import _HEADER @@ -86,7 +89,7 @@ def create_connect_args(self, url): # Save this for later use. self.host = url.host - self.port = url_port + self.port = url.port or defaults.PORT self.username = url.username self.password = url.password self.db = db @@ -96,11 +99,7 @@ def create_connect_args(self, url): # Prepare a session with proper authorization handling. session = Session() # session.verify property which is bydefault true so Handled here - if "verify_ssl" in url.query and url.query["verify_ssl"] in [ - False, - "False", - "false", - ]: + if "verify_ssl" in url.query and url.query["verify_ssl"] in ["False", "false"]: session.verify = False if self.token is not None: @@ -209,3 +208,22 @@ def get_unique_columns(self, columns): columns_set.remove(c["name"]) return unique_columns + + def on_connect_url(self, url): + SolrDialect.solr_spec = SolrSpec(url) + + def do_on_connect(connection): # pylint: disable=unused-argument + SolrDialect.solr_spec = SolrSpec(url) + + if ( + SolrDialect.solr_spec.spec()[0] + < release_flags.SOLR_DATE_RANGE_TRANS_RELEASE + ): + logging.warning( + "Solr version %s less than 9, SQL compilation cache disabled", + SolrDialect.solr_spec.spec()[0], + ) + SolrDialect_http.supports_statement_cache = False + SolrDialect.supports_statement_cache = False + + return do_on_connect diff --git a/src/sqlalchemy_solr/release_flags.py b/src/sqlalchemy_solr/release_flags.py new file mode 100644 index 0000000..25b0c52 --- /dev/null +++ b/src/sqlalchemy_solr/release_flags.py @@ -0,0 +1 @@ +SOLR_DATE_RANGE_TRANS_RELEASE = 9 diff --git a/src/sqlalchemy_solr/solrdbapi/_solrdbapi.py b/src/sqlalchemy_solr/solrdbapi/_solrdbapi.py index 26eef88..78b6f26 100755 --- a/src/sqlalchemy_solr/solrdbapi/_solrdbapi.py +++ b/src/sqlalchemy_solr/solrdbapi/_solrdbapi.py @@ -1,9 +1,9 @@ import logging from requests import Session +from sqlalchemy_solr import defaults from .. import type_map -from ..admin.solr_spec import SolrSpec from ..api_globals import _HEADER from ..api_globals import _PAYLOAD from ..message_formatter import MessageFormatter @@ -286,7 +286,6 @@ def __iter__(self): class Connection: # pylint: disable=too-many-instance-attributes - solr_spec = None mf = MessageFormatter() # pylint: disable=too-many-arguments @@ -314,8 +313,6 @@ def __init__( self._session = session self._connected = True - Connection.solr_spec = SolrSpec(f"{proto}{host}:{port}/{server_path}") - SolrTableReflection.connection = self @property @@ -374,23 +371,23 @@ def cursor(self): # pylint: disable=too-many-arguments def connect( host, - port=8047, - db=None, + db, + server_path, + collection, + port=defaults.PORT, username=None, password=None, - server_path="solr", - collection=None, - use_ssl=False, + use_ssl=None, verify_ssl=None, token=None, ): session = Session() # bydefault session.verify is set to True - if verify_ssl is not None and verify_ssl in [False, "False", "false"]: + if verify_ssl is not None and verify_ssl in ["False", "false"]: session.verify = False - if use_ssl in [True, "True", "true"]: + if use_ssl in ["True", "true"]: proto = "https://" else: proto = "http://" diff --git a/tests/assertions.py b/tests/assertions.py new file mode 100644 index 0000000..0faa7f2 --- /dev/null +++ b/tests/assertions.py @@ -0,0 +1,10 @@ +import pytest +from sqlalchemy_solr.admin.solr_spec import SolrSpec + + +def assert_solr_release(settings, releases): + solr_spec = SolrSpec(settings["SOLR_BASE_URL"]) + if solr_spec.spec()[0] not in releases: + pytest.skip( + reason=f"Solr spec version {solr_spec} not compatible with the current test" + ) diff --git a/tests/test_sql_compilation_caching.py b/tests/test_sql_compilation_caching.py deleted file mode 100644 index 15b07d4..0000000 --- a/tests/test_sql_compilation_caching.py +++ /dev/null @@ -1,58 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.sql.expression import bindparam -from sqlalchemy.util.langhelpers import _symbol -from tests.setup import prepare_orm - -from .fixtures.fixtures import SalesFixture - - -class TestSQLCompilationCaching: - def index_data(self, settings): - f = SalesFixture(settings) - f.truncate_collection() - f.index() - - def test_sql_compilation_caching_1(self, settings): - _, t = prepare_orm(settings) - - qry_1 = (select(t.c.CITY_s).select_from(t)).limit(1) - qry_2 = (select(t.c.CITY_s).select_from(t)).limit(10) - - k1 = qry_1._generate_cache_key() # pylint: disable=protected-access - k2 = qry_2._generate_cache_key() # pylint: disable=protected-access - - assert k1 == k2 - - def test_sql_compilation_caching_2(self, settings): - _, t = prepare_orm(settings) - - qry_1 = (select(t.c.CITY_s).select_from(t)).limit(1).offset(1) - qry_2 = (select(t.c.CITY_s).select_from(t)).limit(1).offset(2) - - k1 = qry_1._generate_cache_key() # pylint: disable=protected-access - k2 = qry_2._generate_cache_key() # pylint: disable=protected-access - - assert k1 == k2 - - def test_sql_compilation_caching_3(self, settings): - engine, t = prepare_orm(settings) - - qry = select(t).where(t.c.CITY_s == bindparam("CITY_s")).limit(10) - - with engine.connect() as connection: - result_1 = connection.execute(qry, {"CITY_s": "Singapore"}) - result_2 = connection.execute(qry, {"CITY_s": "Boras"}) - - assert result_1.context.cache_hit == _symbol("CACHE_MISS") - assert result_2.context.cache_hit == _symbol("CACHE_HIT") - - def test_sql_compilation_caching_4(self, settings): - _, t = prepare_orm(settings) - - qry_1 = select(t).where(t.c.CITY_s == bindparam("CITY_s")).limit(10) - qry_2 = select(t).where(t.c.COUNTRY_s == bindparam("COUNTRY_s")).limit(10) - - k1 = qry_1._generate_cache_key() # pylint: disable=protected-access - k2 = qry_2._generate_cache_key() # pylint: disable=protected-access - - assert k1 != k2 diff --git a/tests/test_sql_compilation_caching_6.py b/tests/test_sql_compilation_caching_6.py new file mode 100644 index 0000000..5cd68ac --- /dev/null +++ b/tests/test_sql_compilation_caching_6.py @@ -0,0 +1,54 @@ +from sqlalchemy import select +from sqlalchemy.sql.expression import bindparam +from sqlalchemy.util.langhelpers import _symbol +from tests import assertions +from tests.setup import prepare_orm + +releases = [6, 7, 8] + + +class TestSQLCompilationCaching: + + def test_sql_compilation_caching_1(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry_1 = (select(t.c.COUNTRY_s).select_from(t)).limit(1) + qry_2 = (select(t.c.COUNTRY_s).select_from(t)).limit(10) + + with engine.connect() as connection: + result_1 = connection.execute(qry_1) + result_2 = connection.execute(qry_2) + + assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") + assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") + + def test_sql_compilation_caching_2(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry_1 = (select(t.c.COUNTRY_s).select_from(t)).limit(1).offset(1) + qry_2 = (select(t.c.COUNTRY_s).select_from(t)).limit(1).offset(2) + + with engine.connect() as connection: + result_1 = connection.execute(qry_1) + result_2 = connection.execute(qry_2) + + assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") + assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") + + def test_sql_compilation_caching_3(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = select(t).where(t.c.COUNTRY_s == bindparam("COUNTRY_s")).limit(10) + + with engine.connect() as connection: + result_1 = connection.execute(qry, {"COUNTRY_s": "Sweden"}) + result_2 = connection.execute(qry, {"COUNTRY_s": "France"}) + + assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") + assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT") diff --git a/tests/test_sql_compilation_caching_9.py b/tests/test_sql_compilation_caching_9.py new file mode 100644 index 0000000..9303fc1 --- /dev/null +++ b/tests/test_sql_compilation_caching_9.py @@ -0,0 +1,179 @@ +from sqlalchemy import and_ +from sqlalchemy import select +from sqlalchemy.sql.expression import bindparam +from sqlalchemy.util.langhelpers import _symbol +from tests import assertions +from tests.setup import prepare_orm + +releases = [9] + + +class TestSQLCompilationCaching: + + def test_sql_compilation_caching_1(self, settings): + assertions.assert_solr_release(settings, releases) + + _, t = prepare_orm(settings) + + qry_1 = (select(t.c.CITY_s).select_from(t)).limit(1) + qry_2 = (select(t.c.CITY_s).select_from(t)).limit(10) + + k1 = qry_1._generate_cache_key() # pylint: disable=protected-access + k2 = qry_2._generate_cache_key() # pylint: disable=protected-access + + assert k1 == k2 + + def test_sql_compilation_caching_2(self, settings): + assertions.assert_solr_release(settings, releases) + + _, t = prepare_orm(settings) + + qry_1 = (select(t.c.CITY_s).select_from(t)).limit(1).offset(1) + qry_2 = (select(t.c.CITY_s).select_from(t)).limit(1).offset(2) + + k1 = qry_1._generate_cache_key() # pylint: disable=protected-access + k2 = qry_2._generate_cache_key() # pylint: disable=protected-access + + assert k1 == k2 + + def test_sql_compilation_caching_3(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = select(t).where(t.c.CITY_s == bindparam("CITY_s")).limit(10) + + with engine.connect() as connection: + result_1 = connection.execute(qry, {"CITY_s": "Singapore"}) + result_2 = connection.execute(qry, {"CITY_s": "Boras"}) + + assert result_1.context.cache_hit == _symbol("CACHE_MISS") + assert result_2.context.cache_hit == _symbol("CACHE_HIT") + + def test_sql_compilation_caching_4(self, settings): + assertions.assert_solr_release(settings, releases) + + _, t = prepare_orm(settings) + + qry_1 = select(t).where(t.c.CITY_s == bindparam("CITY_s")).limit(10) + qry_2 = select(t).where(t.c.COUNTRY_s == bindparam("COUNTRY_s")).limit(10) + + k1 = qry_1._generate_cache_key() # pylint: disable=protected-access + k2 = qry_2._generate_cache_key() # pylint: disable=protected-access + + assert k1 != k2 + + def test_sql_compilation_caching_5(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = select(t).where(t.c.ORDERDATE_dt >= bindparam("ORDERDATE_dt")).limit(10) + + with engine.connect() as connection: + result_1 = connection.execute(qry, {"ORDERDATE_dt": "2018-01-01T00:00:00Z"}) + result_2 = connection.execute(qry, {"ORDERDATE_dt": "2018-01-01T00:00:00Z"}) + + assert result_1.context.cache_hit == _symbol("CACHE_MISS") + assert result_2.context.cache_hit == _symbol("CACHE_HIT") + + def test_sql_compilation_caching_6(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = ( + select(t.c.ORDERNUMBER_i, t.c.ORDERLINENUMBER_i) + .where(t.c.ORDERDATE_dt >= bindparam("ORDERDATE_dt")) + .order_by(t.c.ORDERDATE_dt.asc()) + .limit(10) + ) + + with engine.connect() as connection: + result_1 = connection.execute(qry, {"ORDERDATE_dt": "2017-05-01T00:00:00Z"}) + result_2 = connection.execute(qry, {"ORDERDATE_dt": "2017-06-01T00:00:00Z"}) + + for row in zip(result_1, result_2, strict=False): + assert row[0][0] != row[1][0] + + assert result_1.context.cache_hit == _symbol("CACHE_MISS") + assert result_2.context.cache_hit == _symbol("CACHE_HIT") + + def test_sql_compilation_caching_7(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = ( + select(t) + .where( + and_( + t.c.ORDERDATE_dt + >= bindparam( + "ORDERDATE_dt_1", + t.c.ORDERDATE_dt <= bindparam("ORDERDATE_dt_2"), + ) + ) + ) + .limit(10) + ) + + with engine.connect() as connection: + result_1 = connection.execute( + qry, + { + "ORDERDATE_dt_1": "2017-07-01T00:00:00Z", + "ORDERDATE_dt_2": "2017-08-01T00:00:00Z", + }, + ) + result_2 = connection.execute( + qry, + { + "ORDERDATE_dt_1": "2017-07-01T00:00:00Z", + "ORDERDATE_dt_2": "2017-08-01T00:00:00Z", + }, + ) + + assert result_1.context.cache_hit == _symbol("CACHE_MISS") + assert result_2.context.cache_hit == _symbol("CACHE_HIT") + + def test_sql_compilation_caching_8(self, settings): + assertions.assert_solr_release(settings, releases) + + engine, t = prepare_orm(settings) + + qry = ( + select(t) + .where( + and_( + t.c.ORDERDATE_dt + >= bindparam( + "ORDERDATE_dt_1", + t.c.ORDERDATE_dt <= bindparam("ORDERDATE_dt_2"), + ) + ) + ) + .limit(10) + ) + + with engine.connect() as connection: + result_1 = connection.execute( + qry, + { + "ORDERDATE_dt_1": "2017-09-01T00:00:00Z", + "ORDERDATE_dt_2": "2017-10-01T00:00:00Z", + }, + ) + result_2 = connection.execute( + qry, + { + "ORDERDATE_dt_1": "2017-11-01T00:00:00Z", + "ORDERDATE_dt_2": "2017-12-01T00:00:00Z", + }, + ) + + for row in zip(result_1, result_2, strict=False): + assert row[0][0] != row[1][0] + + assert result_1.context.cache_hit == _symbol("CACHE_MISS") + assert result_2.context.cache_hit == _symbol("CACHE_HIT")