diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index c385ab90..cb62fc01 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -252,7 +252,8 @@ def test_get_default_isolation_level(self): assert isolation_level == "AUTOCOMMIT" def test_isolation_level(self): - dbapi_conn = Connection(host="localhost") + # The test only verifies that isolation level is correctly set, no need to attempt actual connection + dbapi_conn = Connection(host="localhost", defer_connect=True) self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE") assert dbapi_conn._isolation_level == IsolationLevel.SERIALIZABLE diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index b56466a2..e0367c86 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -184,7 +184,8 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post conn2.cursor().execute("SELECT 2") conn2.cursor().execute("SELECT 3") - assert len(_post_statement_requests()) == 7 + assert len(_post_statement_requests()) == 9 + # assert only a single token request was sent assert len(_get_token_requests(challenge_id)) == 1 @@ -275,37 +276,38 @@ def test_role_is_set_when_specified(mock_client): def test_hostname_parsing(): - https_server_with_port = Connection("https://mytrinoserver.domain:9999") + # Since this test only verifies URL parsing there is no need to attempt actual connection + https_server_with_port = Connection("https://mytrinoserver.domain:9999", defer_connect=True) assert https_server_with_port.host == "mytrinoserver.domain" assert https_server_with_port.port == 9999 assert https_server_with_port.http_scheme == constants.HTTPS - https_server_without_port = Connection("https://mytrinoserver.domain") + https_server_without_port = Connection("https://mytrinoserver.domain", defer_connect=True) assert https_server_without_port.host == "mytrinoserver.domain" assert https_server_without_port.port == 8080 assert https_server_without_port.http_scheme == constants.HTTPS - http_server_with_port = Connection("http://mytrinoserver.domain:9999") + http_server_with_port = Connection("http://mytrinoserver.domain:9999", defer_connect=True) assert http_server_with_port.host == "mytrinoserver.domain" assert http_server_with_port.port == 9999 assert http_server_with_port.http_scheme == constants.HTTP - http_server_without_port = Connection("http://mytrinoserver.domain") + http_server_without_port = Connection("http://mytrinoserver.domain", defer_connect=True) assert http_server_without_port.host == "mytrinoserver.domain" assert http_server_without_port.port == 8080 assert http_server_without_port.http_scheme == constants.HTTP - http_server_with_path = Connection("http://mytrinoserver.domain/some_path") + http_server_with_path = Connection("http://mytrinoserver.domain/some_path", defer_connect=True) assert http_server_with_path.host == "mytrinoserver.domain/some_path" assert http_server_with_path.port == 8080 assert http_server_with_path.http_scheme == constants.HTTP - only_hostname = Connection("mytrinoserver.domain") + only_hostname = Connection("mytrinoserver.domain", defer_connect=True) assert only_hostname.host == "mytrinoserver.domain" assert only_hostname.port == 8080 assert only_hostname.http_scheme == constants.HTTP - only_hostname_with_path = Connection("mytrinoserver.domain/some_path") + only_hostname_with_path = Connection("mytrinoserver.domain/some_path", defer_connect=True) assert only_hostname_with_path.host == "mytrinoserver.domain/some_path" assert only_hostname_with_path.port == 8080 assert only_hostname_with_path.http_scheme == constants.HTTP diff --git a/trino/dbapi.py b/trino/dbapi.py index 62ce893b..ae1348a3 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -28,6 +28,8 @@ from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types from urllib.parse import urlparse +from requests.exceptions import RequestException + try: from zoneinfo import ZoneInfo except ModuleNotFoundError: @@ -157,6 +159,7 @@ def __init__( legacy_prepared_statements=None, roles=None, timezone=None, + defer_connect=False, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) @@ -201,6 +204,31 @@ def __init__( self.legacy_primitive_types = legacy_primitive_types self.legacy_prepared_statements = legacy_prepared_statements + if not defer_connect: + self.connect() + + def connect(self) -> None: + connection_test_request = trino.client.TrinoRequest( + self.host, + self.port, + self._client_session, + self._http_session, + self.http_scheme, + self.auth, + self.max_attempts, + self.request_timeout, + verify=self._http_session.verify, + ) + try: + test_response = connection_test_request.post("") + response_content = test_response.content if test_response.content else "" + if not test_response.ok: + raise trino.exceptions.TrinoConnectionError( + "error {}: {}".format(test_response.status_code, response_content)) + + except RequestException as e: + raise trino.exceptions.TrinoConnectionError("connection failed: {}".format(e)) + @property def isolation_level(self): return self._isolation_level