diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a604fb3..672fd065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Deprecated ### Removed ### Fixed +- Fix `Transport.perform_request`'s arguments `timeout` and `ignore` variable usage ([810](https://github.com/opensearch-project/opensearch-py/pull/810)) ### Security ### Dependencies diff --git a/opensearchpy/_async/transport.py b/opensearchpy/_async/transport.py index bc9e257f..2d631ee7 100644 --- a/opensearchpy/_async/transport.py +++ b/opensearchpy/_async/transport.py @@ -371,11 +371,13 @@ async def perform_request( underlying :class:`~opensearchpy.Connection` class for serialization :arg body: body of the request, will be serialized using serializer and passed to the connection + :arg timeout: timeout of the request. If it is not presented as argument + will be extracted from `params` """ await self._async_call() method, params, body, ignore, timeout = self._resolve_request_args( - method, params, body + method, params, body, ignore, timeout ) for attempt in range(self.max_retries + 1): diff --git a/opensearchpy/transport.py b/opensearchpy/transport.py index 2385b2b8..559d0387 100644 --- a/opensearchpy/transport.py +++ b/opensearchpy/transport.py @@ -404,9 +404,11 @@ def perform_request( underlying :class:`~opensearchpy.Connection` class for serialization :arg body: body of the request, will be serialized using serializer and passed to the connection + :arg timeout: timeout of the request. If it is not presented as argument + will be extracted from `params` """ method, params, body, ignore, timeout = self._resolve_request_args( - method, params, body + method, params, body, ignore, timeout ) for attempt in range(self.max_retries + 1): @@ -473,7 +475,14 @@ def close(self) -> Any: """ return self.connection_pool.close() - def _resolve_request_args(self, method: str, params: Any, body: Any) -> Any: + def _resolve_request_args( + self, + method: str, + params: Any, + body: Any, + ignore: Collection[int], + timeout: Optional[Union[int, float]], + ) -> Any: """Resolves parameters for .perform_request()""" if body is not None: body = self.serializer.dumps(body) @@ -498,13 +507,13 @@ def _resolve_request_args(self, method: str, params: Any, body: Any) -> Any: # bytes/str - no need to re-encode pass - ignore = () - timeout = None if params: - timeout = params.pop("request_timeout", None) if not timeout: - timeout = params.pop("timeout", None) - ignore = params.pop("ignore", ()) + timeout = params.pop("request_timeout", None) or params.pop( + "timeout", None + ) + if not ignore: + ignore = params.pop("ignore", ()) if isinstance(ignore, int): ignore = (ignore,) diff --git a/test_opensearchpy/test_async/test_transport.py b/test_opensearchpy/test_async/test_transport.py index e3048a48..51388ddd 100644 --- a/test_opensearchpy/test_async/test_transport.py +++ b/test_opensearchpy/test_async/test_transport.py @@ -175,6 +175,20 @@ async def test_opaque_id(self) -> None: "headers": {"x-opaque-id": "request-1"}, } == t.get_connection().calls[1][1] + async def test_perform_request_all_arguments_passed_correctly(self) -> None: + t: Any = AsyncTransport([{}], connection_class=DummyConnection) + method, url, params, body = "GET", "/test_url", {"params": "test"}, "test_body" + timeout, ignore, headers = 11, ("ignore",), {"h": "test"} + + await t.perform_request(method, url, params, body, timeout, ignore, headers) + + assert t.get_connection().calls == [ + ( + (method, url, params, body.encode()), + {"headers": headers, "ignore": ignore, "timeout": timeout}, + ) + ] + async def test_request_with_custom_user_agent_header(self) -> None: t: Any = AsyncTransport([{}], connection_class=DummyConnection) diff --git a/test_opensearchpy/test_transport.py b/test_opensearchpy/test_transport.py index ccef21df..96422a3f 100644 --- a/test_opensearchpy/test_transport.py +++ b/test_opensearchpy/test_transport.py @@ -260,6 +260,23 @@ def test_add_connection(self) -> None: "http://google.com:1234", t.connection_pool.connections[1].host ) + def test_perform_request_all_arguments_passed_correctly(self) -> None: + t: Any = Transport([{}], connection_class=DummyConnection) + method, url, params, body = "GET", "/test_url", {"params": "test"}, "test_body" + timeout, ignore, headers = 11, ("ignore",), {"h": "test"} + + t.perform_request(method, url, params, body, timeout, ignore, headers) + + self.assertEqual( + t.get_connection().calls, + [ + ( + (method, url, params, body.encode()), + {"headers": headers, "ignore": ignore, "timeout": timeout}, + ) + ], + ) + def test_request_will_fail_after_x_retries(self) -> None: t: Any = Transport( [{"exception": ConnectionError(None, "abandon ship", Exception())}],