Skip to content

Commit

Permalink
Added test for synchronous path and fixed the bug
Browse files Browse the repository at this point in the history
Signed-off-by: Nathalie Jonathan <[email protected]>
  • Loading branch information
nathaliellenaa committed Nov 22, 2024
1 parent da8cb58 commit 9926ff2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Fixed
- Fix `Transport.perform_request`'s arguments `timeout` and `ignore` variable usage ([810](https://github.com/opensearch-project/opensearch-py/pull/810))
- Fix `Index.save` not passing through aliases to the underlying index ([823](https://github.com/opensearch-project/opensearch-py/pull/823))
- Fix bug where the URL being sent and being signed is different ([848](https://github.com/opensearch-project/opensearch-py/pull/848))
- Fix `AuthorizationException` with AWS OpenSearch when the doc ID contains `:` ([848](https://github.com/opensearch-project/opensearch-py/pull/848))
### Security

### Dependencies
Expand Down
9 changes: 4 additions & 5 deletions opensearchpy/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,24 +384,23 @@ async def perform_request(
method, params, body, ignore, timeout
)

from urllib.parse import unquote

decoded_url = unquote(url)
from urllib.parse import quote

for attempt in range(self.max_retries + 1):
connection = self.get_connection()

try:
status, headers_response, data = await connection.perform_request(
method,
decoded_url,
url,
params,
body,
headers=headers,
ignore=ignore,
timeout=timeout,
)

url = quote(url)

# Lowercase all the header names for consistency in accessing them.
headers_response = {
header.lower(): value for header, value in headers_response.items()
Expand Down
3 changes: 1 addition & 2 deletions opensearchpy/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def _make_path(*parts: Any) -> str:
"""
# TODO: maybe only allow some parts to be lists/tuples ?
return "/" + "/".join(
# preserve ',' and '*' in url for nicer URLs in logs
quote(_escape(p), b",*")
str(p)
for p in parts
if p not in SKIP_IN_PATH
)
Expand Down
7 changes: 1 addition & 6 deletions test_opensearchpy/test_async/test_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import pytest
from _pytest.mark.structures import MarkDecorator

# from urllib.parse import quote, unquote


pytestmark: MarkDecorator = pytest.mark.asyncio


Expand Down Expand Up @@ -134,7 +131,7 @@ async def test_aws_signer_async_consitent_url(self) -> None:
signed_url = None
sent_url = None

doc_id = "test:123"
doc_id = "doc_id:with!special*chars%3A"
url = f"https://search-domain.region.es.amazonaws.com:9200/index/_doc/{doc_id}"

# Create a mock signer class to capture the signed URL
Expand Down Expand Up @@ -177,6 +174,4 @@ async def perform_request(
connection_class=MockConnection,
)
await client.index("index", {"test": "data"}, id=doc_id)

# Verify URLs match
assert signed_url == sent_url, "URLs don't match"
50 changes: 50 additions & 0 deletions test_opensearchpy/test_connection/test_requests_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,56 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None:
("GET", "http://localhost/?key1=value1&key2=value2", None),
)

def test_aws_signer_consitent_url(self) -> None:
    """Check that the signed URL equals the URL handed to the connection.

    The document ID contains ``:``, ``!``, ``*`` and an already
    percent-encoded sequence (``%3A``). The test records the URL seen by
    a mock signer and the URL seen by a mock connection, then asserts
    that the two strings are identical.
    """
    from opensearchpy import OpenSearch
    from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth

    aws_region = "us-west-2"
    doc_id = "doc_id:with!special*chars%3A"
    expected_url = (
        f"https://search-domain.region.es.amazonaws.com:9200/index/_doc/{doc_id}"
    )

    # Filled in by the two mocks below through ``nonlocal``.
    signed_url = None
    sent_url = None

    class MockSigner(RequestsAWSV4SignerAuth):
        """Signer that records the URL it is called with instead of signing."""

        def __call__(self, prepared_request):
            nonlocal signed_url
            # Accept either a plain URL string or an object exposing ``.url``.
            signed_url = (
                prepared_request
                if isinstance(prepared_request, str)
                else prepared_request.url
            )
            return prepared_request

    class MockConnection(RequestsHttpConnection):
        """Connection that records the outgoing URL and fakes a 200 reply."""

        def perform_request(
            self,
            method: str,
            url: str,
            params=None,
            body=None,
            timeout=None,
            ignore=(),
            headers=None,
        ):
            nonlocal sent_url
            sent_url = f"{self.host}{url}"
            return 200, {}, "{}"

    signer = MockSigner(self.mock_session(), aws_region)

    # Calling the signer with a string records that string as ``signed_url``
    # and returns it unchanged, so the client receives the URL string here.
    http_auth = signer(expected_url)

    opensearch_client = OpenSearch(
        hosts=[{"host": "search-domain.region.es.amazonaws.com"}],
        http_auth=http_auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=MockConnection,
    )
    opensearch_client.index("index", {"test": "data"}, id=doc_id)

    assert signed_url == sent_url, "URLs don't match"

class TestRequestsConnectionRedirect(TestCase):
server1: TestHTTPServer
Expand Down

0 comments on commit 9926ff2

Please sign in to comment.