Skip to content

Commit

Permalink
[7.13] Tolerate RecursionError not being defined in Python<3.5
Browse files Browse the repository at this point in the history
Co-authored-by: Seth Michael Larson <[email protected]>
  • Loading branch information
github-actions[bot] and sethmlarson authored Jun 29, 2021
1 parent e71001d commit ba8f295
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 7 deletions.
4 changes: 2 additions & 2 deletions elasticsearch/_async/http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import urllib3 # type: ignore

from ..compat import urlencode
from ..compat import reraise_exceptions, urlencode
from ..connection.base import Connection
from ..exceptions import (
ConnectionError,
Expand Down Expand Up @@ -304,7 +304,7 @@ async def perform_request(
duration = self.loop.time() - start

# We want to reraise a cancellation or recursion error.
except (asyncio.CancelledError, RecursionError):
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
14 changes: 14 additions & 0 deletions elasticsearch/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,22 @@ def to_bytes(x, encoding="ascii"):
from collections import Mapping


try:
reraise_exceptions = (RecursionError,)
except NameError:
reraise_exceptions = ()

try:
import asyncio

reraise_exceptions += (asyncio.CancelledError,)
except (ImportError, AttributeError):
pass


__all__ = [
"string_types",
"reraise_exceptions",
"quote_plus",
"quote",
"urlencode",
Expand Down
3 changes: 2 additions & 1 deletion elasticsearch/compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.

import sys
from typing import Callable, Tuple, Union
from typing import Callable, Tuple, Type, Union

PY2: bool
string_types: Tuple[type, ...]

to_str: Callable[[Union[str, bytes]], str]
to_bytes: Callable[[Union[str, bytes]], bytes]
reraise_exceptions: Tuple[Type[Exception], ...]

if sys.version_info[0] == 2:
from itertools import imap as map
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch/connection/http_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import time
import warnings

from ..compat import string_types, urlencode
from ..compat import reraise_exceptions, string_types, urlencode
from ..exceptions import (
ConnectionError,
ConnectionTimeout,
Expand Down Expand Up @@ -166,7 +166,7 @@ def perform_request(
response = self.session.send(prepared_request, **send_kwargs)
duration = time.time() - start
raw_data = response.content.decode("utf-8", "surrogatepass")
except RecursionError:
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch/connection/http_urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from urllib3.exceptions import SSLError as UrllibSSLError # type: ignore
from urllib3.util.retry import Retry # type: ignore

from ..compat import urlencode
from ..compat import reraise_exceptions, urlencode
from ..exceptions import (
ConnectionError,
ConnectionTimeout,
Expand Down Expand Up @@ -253,7 +253,7 @@ def perform_request(
)
duration = time.time() - start
raw_data = response.data.decode("utf-8", "surrogatepass")
except RecursionError:
except reraise_exceptions:
raise
except Exception as e:
self.log_request_fail(
Expand Down
21 changes: 21 additions & 0 deletions test_elasticsearch/test_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from multidict import CIMultiDict

from elasticsearch import AIOHttpConnection, __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.exceptions import ConnectionError

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -318,6 +320,20 @@ async def test_surrogatepass_into_bytes(self):
status, headers, data = await con.perform_request("GET", "/")
assert u"你好\uda6a" == data

@pytest.mark.parametrize("exception_cls", reraise_exceptions)
async def test_recursion_error_reraised(self, exception_cls):
conn = AIOHttpConnection()

def request_raise(*_, **__):
raise exception_cls("Wasn't modified!")

await conn._create_aiohttp_session()
conn.session.request = request_raise

with pytest.raises(exception_cls) as e:
await conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestConnectionHttpbin:
"""Tests the HTTP connection implementations against a live server E2E"""
Expand Down Expand Up @@ -389,3 +405,8 @@ async def test_aiohttp_connection(self):
"Header2": "value2",
"User-Agent": user_agent,
}

async def test_aiohttp_connection_error(self):
conn = AIOHttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
await conn.perform_request("GET", "/")
42 changes: 42 additions & 0 deletions test_elasticsearch/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
from urllib3._collections import HTTPHeaderDict

from elasticsearch import __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.connection import (
Connection,
RequestsHttpConnection,
Urllib3HttpConnection,
)
from elasticsearch.exceptions import (
ConflictError,
ConnectionError,
NotFoundError,
RequestError,
TransportError,
Expand Down Expand Up @@ -466,6 +468,21 @@ def test_surrogatepass_into_bytes(self):
status, headers, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data)

@pytest.mark.skipif(
not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5"
)
def test_recursion_error_reraised(self):
conn = Urllib3HttpConnection()

def urlopen_raise(*_, **__):
raise RecursionError("Wasn't modified!")

conn.pool.urlopen = urlopen_raise

with pytest.raises(RecursionError) as e:
conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestRequestsConnection(TestCase):
def _get_mock_connection(
Expand Down Expand Up @@ -868,6 +885,21 @@ def test_surrogatepass_into_bytes(self):
status, headers, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data)

@pytest.mark.skipif(
not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5"
)
def test_recursion_error_reraised(self):
conn = RequestsHttpConnection()

def send_raise(*_, **__):
raise RecursionError("Wasn't modified!")

conn.session.send = send_raise

with pytest.raises(RecursionError) as e:
conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"


class TestConnectionHttpbin:
"""Tests the HTTP connection implementations against a live server E2E"""
Expand Down Expand Up @@ -942,6 +974,11 @@ def test_urllib3_connection(self):
"User-Agent": user_agent,
}

def test_urllib3_connection_error(self):
conn = Urllib3HttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
conn.perform_request("GET", "/")

def test_requests_connection(self):
# Defaults
conn = RequestsHttpConnection("httpbin.org", port=443, use_ssl=True)
Expand Down Expand Up @@ -1003,3 +1040,8 @@ def test_requests_connection(self):
"Header2": "value2",
"User-Agent": user_agent,
}

def test_requests_connection_error(self):
conn = RequestsHttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
conn.perform_request("GET", "/")

0 comments on commit ba8f295

Please sign in to comment.