Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add an option to deferred fetch result in Cursor.execute() #400

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import pytest

from tests.unit.oauth_test_utils import SERVER_ADDRESS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we put SERVER_ADDRESS in a more central place if we want to use it in multiple places. here the usage is not related to oauth.


@pytest.fixture(scope="session")

@pytest.fixture
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change fixture scope to function to fix issue where it's modified between test cases.

def sample_post_response_data():
"""
This is the response to the first HTTP request (a POST) from an actual
Expand All @@ -38,10 +40,10 @@ def sample_post_response_data():
"""

yield {
"nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1",
"nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/1",
"id": "20210817_140827_00000_arvdv",
"taskDownloadUris": [],
"infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv",
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv",
"stats": {
"scheduled": False,
"runningSplits": 0,
Expand All @@ -60,7 +62,7 @@ def sample_post_response_data():
}


@pytest.fixture(scope="session")
@pytest.fixture
def sample_get_response_data():
"""
This is the response to the second HTTP request (a GET) from an actual
Expand All @@ -73,7 +75,7 @@ def sample_get_response_data():
"""
yield {
"id": "20210817_140827_00000_arvdv",
"nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/2",
"nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/2",
"data": [
["UUID-0", "http://worker0:8080", "0.157", False, "active"],
["UUID-1", "http://worker1:8080", "0.157", False, "active"],
Expand Down Expand Up @@ -132,7 +134,7 @@ def sample_get_response_data():
},
],
"taskDownloadUris": [],
"partialCancelUri": "http://localhost:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501
"partialCancelUri": f"{SERVER_ADDRESS}:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501
"stats": {
"nodes": 2,
"processedBytes": 880,
Expand Down Expand Up @@ -181,11 +183,11 @@ def sample_get_response_data():
"queuedSplits": 0,
"wallTimeMillis": 36,
},
"infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501
}


@pytest.fixture(scope="session")
@pytest.fixture
def sample_get_error_response_data():
yield {
"error": {
Expand All @@ -195,8 +197,7 @@ def sample_get_error_response_data():
"errorType": "USER_ERROR",
"failureInfo": {
"errorLocation": {"columnNumber": 15, "lineNumber": 1},
"message": "line 1:15: Schema must be specified "
"when session schema is not set",
"message": "line 1:15: Schema must be specified when session schema is not set",
"stack": [
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)",
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)",
Expand Down Expand Up @@ -241,7 +242,7 @@ def sample_get_error_response_data():
"message": "line 1:15: Schema must be specified when session schema is not set",
},
"id": "20210817_140827_00000_arvdv",
"infoUri": "http://localhost:8080/query.html?20210817_140827_00000_arvdv",
"infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv",
"stats": {
"completedSplits": 0,
"cpuTimeMillis": 0,
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,57 @@ def json(self):
assert isinstance(result, TrinoResult)


def test_trino_query_deferred_fetch(sample_get_response_data):
"""
Validates that the `TrinoQuery.execute` function deferred_fetch and non-block execution
"""

class MockResponse(mock.Mock):
Copy link

@giftig giftig Aug 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is copied and pasted from another function, so it would be nice to extract it out and stay DRY.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's depends on sample_get_response_data local variable.
Still have a way to extract it out, but I wonder is it generic enough for other tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion is keep it as is. if there are more tests use it, we could extract it later, in other PR.

Copy link

@giftig giftig Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's something which could just be passed into the constructor though, right? There's no reason for it to refer to local scope. It's a whole class which has been copied and pasted and defined in a local function, I think now is the right time to refactor it and stay DRY personally.

# Fake response class
@property
def headers(self):
return {
'X-Trino-Fake-1': 'one',
'X-Trino-Fake-2': 'two',
}

def json(self):
return sample_get_response_data

rows = sample_get_response_data['data']
sample_get_response_data['data'] = []
sql = 'SELECT 1'
request = TrinoRequest(
host="coordinator",
port=8080,
client_session=ClientSession(
user="test",
source="test",
catalog="test",
schema="test",
properties={},
),
http_scheme="http",
)
query = TrinoQuery(
dungdm93 marked this conversation as resolved.
Show resolved Hide resolved
request=request,
query=sql
)

with \
mock.patch.object(request, 'post', return_value=MockResponse()), \
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch:
result = query.execute()
mock_fetch.assert_called_once()
assert result.rows == rows

with \
mock.patch.object(request, 'post', return_value=MockResponse()), \
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch:
result = query.execute(deferred_fetch=True)
mock_fetch.assert_not_called()


def test_delay_exponential_without_jitter():
max_delay = 1200.0
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from unittest.mock import patch

import httpretty
from httpretty import httprettified
from requests import Session

from tests.unit.oauth_test_utils import (
Expand Down Expand Up @@ -58,7 +57,7 @@ def test_http_session_is_defaulted_when_not_specified(mock_client):
assert mock_client.TrinoRequest.http.Session.return_value in request_args


@httprettified
@httpretty.activate
def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())
Expand All @@ -73,13 +72,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)
body=post_statement_callback
)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)
body=get_statement_callback
)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
Expand Down Expand Up @@ -122,7 +123,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl
assert len(_get_token_requests(challenge_id)) == 2


@httprettified
@httpretty.activate
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
sample_get_response_data):
token = str(uuid.uuid4())
Expand Down Expand Up @@ -188,7 +189,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
assert len(_get_token_requests(challenge_id)) == 1


@httprettified
@httpretty.activate
def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())
Expand Down
2 changes: 1 addition & 1 deletion trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def _determine_host(url: Optional[str]) -> Any:


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
def __init__(self, redirect_auth_url_handler: RedirectHandler = CompositeRedirectHandler([
WebBrowserRedirectHandler(),
ConsoleRedirectHandler()
])):
Expand Down
36 changes: 21 additions & 15 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,11 +688,11 @@ def __init__(self, query, rows: List[Any]):
self._rownumber = 0

@property
def rows(self):
def rows(self) -> List[Any]:
return self._rows

@rows.setter
def rows(self, rows):
def rows(self, rows: List[Any]):
self._rows = rows

@property
Expand All @@ -702,14 +702,13 @@ def rownumber(self) -> int:
def __iter__(self):
# A query only transitions to a FINISHED state when the results are fully consumed:
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
while not self._query.finished or self._rows is not None:
next_rows = self._query.fetch() if not self._query.finished else None
while not self._query.finished or self._rows:
for row in self._rows:
self._rownumber += 1
logger.debug("row %s", row)
yield row

self._rows = next_rows
self._rows = self._query.fetch() if not self._query.finished else []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Through this change the above comment is not valid anymore as the next_uri is only acknowledged after exposing all rows from self._rows through dbapi. The result is the same but now the query will be hanging until the dbapi method fetch_one or fetch_many has processed all rows in the buffer.

Also in the future the fetch could be made async, done in a separate thread, which allows for continuously looping through rows and fetching the next resultset.



class TrinoQuery(object):
Expand Down Expand Up @@ -778,13 +777,18 @@ def result(self):
def info_uri(self):
return self._info_uri

def execute(self, additional_http_headers=None) -> TrinoResult:
"""Initiate a Trino query by sending the SQL statement

This is the first HTTP request sent to the coordinator.
It sets the query_id and returns a Result object used to
track the rows returned by the query. To fetch all rows,
call fetch() until finished is true.
def execute(
self,
additional_http_headers: Optional[Dict[str, Any]] = None,
deferred_fetch: bool = False,
) -> TrinoResult:
"""Initiate a Trino query by sending the SQL statement to the coordinator.
To fetch all rows, call fetch() until finished is true.

Parameters:
additional_http_headers: extra headers send to the Trino server.
deferred_fetch: By default, the execution is blocked until at least one row is received
or query is finished or cancelled. To continue without waiting the result, set deferred_fetch=True.
"""
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)
Expand All @@ -805,9 +809,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
if not deferred_fetch:
# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()

return self._result

def _update_state(self, status):
Expand Down
20 changes: 16 additions & 4 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
def _generate_unique_statement_name(self):
return 'st_' + uuid.uuid4().hex.replace('-', '')

def execute(self, operation, params=None):
def execute(self, operation, params=None, **kwargs: Any):
additional_http_headers = kwargs.get("additional_http_headers", None)
deferred_fetch = kwargs.get("deferred_fetch", False)

if params:
assert isinstance(params, (list, tuple)), (
'params must be a list or tuple containing the query '
Expand All @@ -575,7 +578,10 @@ def execute(self, operation, params=None):
self._query = self._execute_prepared_statement(
statement_name, params
)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))
finally:
# Send deallocate statement
# At this point the query can be deallocated since it has already
Expand All @@ -584,12 +590,18 @@ def execute(self, operation, params=None):
self._deallocate_prepared_statement(statement_name)
else:
self._query = self._execute_immediate_statement(operation, params)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))

else:
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._iterator = iter(self._query.execute())
self._iterator = iter(self._query.execute(
additional_http_headers=additional_http_headers,
deferred_fetch=deferred_fetch,
))
return self

def executemany(self, operation, seq_of_params):
Expand Down
3 changes: 2 additions & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
def do_execute(
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
):
cursor.execute(statement, parameters)
execution_options = (context.execution_options or {}) if context else {}
cursor.execute(statement, parameters, **execution_options)

def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
if dbapi_connection.transaction is not None:
Expand Down