Skip to content

Commit

Permalink
added tests for using host header for AWS request signature on both s…
Browse files Browse the repository at this point in the history
…ync and async clients
  • Loading branch information
brunomurino committed Nov 30, 2024
1 parent 571741c commit c9dae0e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
18 changes: 9 additions & 9 deletions opensearchpy/helpers/asyncsigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _sign_request(
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

signature_host = self._fetch_url(url, headers or dict()) # type: ignore
signature_host = self._fetch_url(url, headers or dict())

# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(
Expand Down Expand Up @@ -86,25 +86,25 @@ def _sign_request(
# copy the headers from AWS request object into the prepared_request
return dict(aws_request.headers.items())

def _fetch_url(self, url, headers): # type: ignore
def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
url = urlparse(url)
path = url.path or "/"
parsed_url = urlparse(url)
path = parsed_url.path or "/"

# fetch the query string if present in the request
querystring = ""
if url.query:
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)

# fetch the host information from headers
headers = {key.lower(): value for key, value in headers.items()}
location = headers.get("host") or url.netloc
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc

# construct the url and return
return url.scheme + "://" + location + path + querystring
return parsed_url.scheme + "://" + location + path + querystring
8 changes: 4 additions & 4 deletions opensearchpy/helpers/signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def _sign_request(self, prepared_request): # type: ignore
prepared_request.headers.update(
self.signer.sign(
prepared_request.method,
self._fetch_url(prepared_request), # type: ignore
self._fetch_url(prepared_request),
prepared_request.body,
)
)

return prepared_request

def _fetch_url(self, prepared_request): # type: ignore
def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
Expand All @@ -112,7 +112,7 @@ def _fetch_url(self, prepared_request): # type: ignore
querystring = ""
if url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
)

# fetch the host information from headers
Expand All @@ -122,7 +122,7 @@ def _fetch_url(self, prepared_request): # type: ignore
location = headers.get("host") or url.netloc

# construct the url and return
return url.scheme + "://" + location + path + querystring
return url.scheme + "://" + location + path + querystring # type: ignore


# Deprecated: use RequestsAWSV4SignerAuth
Expand Down
14 changes: 14 additions & 0 deletions test_opensearchpy/test_async/test_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ async def test_aws_signer_async_when_service_is_specified(self) -> None:
assert "X-Amz-Security-Token" in headers
assert "X-Amz-Content-SHA256" in headers

async def test_aws_signer_async_fetch_url_with_querystring(self) -> None:
region = "us-west-2"
service = "aoss"

from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth

auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)

signature_host = auth._fetch_url(
"http://localhost/?foo=bar", headers={"host": "otherhost"}
)

assert signature_host == "http://otherhost/?foo=bar"


class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
def mock_session(self, disable_get_frozen: bool = True) -> Mock:
Expand Down
17 changes: 17 additions & 0 deletions test_opensearchpy/test_connection/test_requests_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,23 @@ def mock_session(self) -> Any:

return dummy_session

def test_aws_signer_fetch_url_with_querystring(self) -> None:
region = "us-west-2"

import requests

from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth

auth = RequestsAWSV4SignerAuth(self.mock_session(), region)

prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()

signature_host = auth._fetch_url(prepared_request)

assert signature_host == "http://otherhost:443/?foo=bar"

def test_aws_signer_as_http_auth(self) -> None:
region = "us-west-2"

Expand Down

0 comments on commit c9dae0e

Please sign in to comment.