diff --git a/peasant/client/transport.py b/peasant/client/transport.py index cc3fdbf..2958a8f 100644 --- a/peasant/client/transport.py +++ b/peasant/client/transport.py @@ -24,15 +24,15 @@ logger = logging.getLogger(__name__) -def concat_url(url: str, **kwargs: dict) -> str: +def concat_url(url: str, path: str = None, **kwargs: dict) -> str: """ Concatenate a given url to a path, and query string if informed. :param str url: Base url + :param str path: Path to be added to the returned url :param dict kwargs: :key path: Path to be added to the returned url :key query_string: Query string to be added to the returned url """ - path = kwargs.get("path", None) query_string = kwargs.get("query_string", None) if query_string: if isinstance(query_string, dict): diff --git a/peasant/client/transport_requests.py b/peasant/client/transport_requests.py index 60f4b6e..0ab3c74 100644 --- a/peasant/client/transport_requests.py +++ b/peasant/client/transport_requests.py @@ -45,17 +45,12 @@ def __init__(self, bastion_address): raise NotImplementedError self._bastion_address = bastion_address self._directory = None - self.user_agent = (f"Peasant/{get_version()}" + self.user_agent = (f"Peasant/{get_version()} " f"Requests/{requests.__version__}") self.basic_headers = { 'User-Agent': self.user_agent } - def _get_path(self, path, **kwargs): - query_string = kwargs.get('query_string') - if query_string: - path = concat_url(path, query_string=query_string) - def get_headers(self, **kwargs): headers = copy.deepcopy(self.basic_headers) _headers = kwargs.get('headers') @@ -63,11 +58,35 @@ def get_headers(self, **kwargs): headers.update(_headers) return headers - async def get(self, path, **kwargs): - url = concat_url(self._bastion_address, **kwargs) + def head(self, path, **kwargs): + url = concat_url(self._bastion_address, path, **kwargs) + headers = self.get_headers(**kwargs) + kwargs['headers'] = headers + try: + with requests.head(url, **kwargs) as result: + result.raise_for_status() + except requests.HTTPError as error: + raise error + return result + + def get(self, path, **kwargs): + url = concat_url(self._bastion_address, path, **kwargs) headers = self.get_headers(**kwargs) + kwargs['headers'] = headers try: - result = requests.get(url, headers=headers) + result = requests.get(url, **kwargs) + result.raise_for_status() except requests.HTTPError as error: result = error.response return result + + def post(self, path, **kwargs): + url = concat_url(self._bastion_address, path, **kwargs) + headers = self.get_headers(**kwargs) + kwargs['headers'] = headers + try: + with requests.post(url, **kwargs) as result: + result.raise_for_status() + except requests.HTTPError as error: + raise error + return result diff --git a/peasant/client/transport_tornado.py b/peasant/client/transport_tornado.py index 619b55b..ae5409c 100644 --- a/peasant/client/transport_tornado.py +++ b/peasant/client/transport_tornado.py @@ -70,7 +70,7 @@ def __init__(self, bastion_address) -> None: self._client = AsyncHTTPClient() self._bastion_address = fix_address(bastion_address) self._directory = None - self.user_agent = (f"Peasant/{get_version()}" + self.user_agent = (f"Peasant/{get_version()} " f"Tornado/{tornado_version}") self._basic_headers = { 'User-Agent': self.user_agent @@ -83,19 +83,19 @@ def get_headers(self, **kwargs): headers.update(_headers) return headers - async def get(self, **kwargs): - url = concat_url(self._bastion_address, **kwargs) + async def get(self, path: str, **kwargs: dict): + url = concat_url(self._bastion_address, path, **kwargs) request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) request.headers.update(headers) try: result = await self._client.fetch(request) except HTTPClientError as error: - result = error.response + raise error return result - async def head(self, **kwargs): - url = concat_url(self._bastion_address, **kwargs) + async def head(self, path: str, **kwargs: dict): + url = concat_url(self._bastion_address, path, **kwargs) kwargs["method"] = "HEAD" request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) @@ -103,11 +103,11 @@ async def head(self, **kwargs): try: result = await self._client.fetch(request) except HTTPClientError as error: - result = error.response + raise error return result - async def post(self, **kwargs): - url = concat_url(self._bastion_address, **kwargs) + async def post(self, path: str, **kwargs: dict): + url = concat_url(self._bastion_address, path, **kwargs) kwargs["method"] = "POST" request = get_tornado_request(url, **kwargs) headers = self.get_headers(**kwargs) @@ -115,5 +115,5 @@ async def post(self, **kwargs): try: result = await self._client.fetch(request) except HTTPClientError as error: - result = error.response + raise error return result diff --git a/tests/runtests.py b/tests/runtests.py index 477b4b5..198aead 100644 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -15,15 +15,16 @@ # limitations under the License. import unittest -from tests import requests_test, tornado_test, transport_test +from tests import (transport_requests_test, transport_test, + transport_tornado_test) def suite(): testLoader = unittest.TestLoader() alltests = unittest.TestSuite() - alltests.addTests(testLoader.loadTestsFromModule(requests_test)) - alltests.addTests(testLoader.loadTestsFromModule(tornado_test)) + alltests.addTests(testLoader.loadTestsFromModule(transport_requests_test)) alltests.addTests(testLoader.loadTestsFromModule(transport_test)) + alltests.addTests(testLoader.loadTestsFromModule(transport_tornado_test)) return alltests diff --git a/tests/requests_test.py b/tests/transport_requests_test.py similarity index 51% rename from tests/requests_test.py rename to tests/transport_requests_test.py index 2deba1f..f53c49c 100644 --- a/tests/requests_test.py +++ b/tests/transport_requests_test.py @@ -29,13 +29,43 @@ def get_launcher(self) -> ProcessLauncher: def setUp(self) -> None: super().setUp() + # setting tests simplefilter to ignore because requests uses a + # keep-alive model, not closing sockets explicitly in many cases. + # with that will cause the ResourceWarning warn be displayed in testing + # as unittests will set warnings.simplefilter to default. + # See: + # - https://github.com/psf/requests/issues/3912#issuecomment-284328247 + # - https://python.readthedocs.io/en/stable/library/warnings.html#updating-code-for-new-versions-of-python + import warnings + warnings.simplefilter("ignore") self.transport = RequestsTransport( f"http://localhost:{self.http_port()}") + @gen_test + async def test_head(self): + try: + response = self.transport.head("/head") + except Exception as e: + raise e + self.assertEqual(response.headers.get("head-response"), + "Head method response") + self.assertEqual(response.headers.get("user-agent"), + self.transport.user_agent) + @gen_test async def test_get(self): try: - response = await self.transport.get("/") + response = self.transport.get("/") except Exception as e: raise e self.assertEqual(response.content, b"Get method output") + + @gen_test + async def test_post(self): + expected_body = "da body" + try: + response = self.transport.post("/post", data="da body") + except Exception as e: + raise e + self.assertEqual(expected_body, response.headers.get("request-body")) + self.assertEqual(response.content, b"Post method output") diff --git a/tests/tornado_test.py b/tests/transport_tornado_test.py similarity index 86% rename from tests/tornado_test.py rename to tests/transport_tornado_test.py index ca69ee6..19c3d01 100644 --- a/tests/tornado_test.py +++ b/tests/transport_tornado_test.py @@ -34,7 +34,7 @@ def setUp(self) -> None: @gen_test async def test_head(self): try: - response = await self.transport.head(path="/head") + response = await self.transport.head("/head") except Exception as e: raise e self.assertEqual(response.headers.get("head-response"), @@ -45,15 +45,17 @@ async def test_head(self): @gen_test async def test_get(self): try: - response = await self.transport.get(path="/") + response = await self.transport.get("/") except Exception as e: raise e self.assertEqual(response.body, b"Get method output") @gen_test async def test_post(self): + expected_body = "da body" try: - response = await self.transport.post(path="/post", body="empty") + response = await self.transport.post("/post", body="da body") except Exception as e: raise e + self.assertEqual(expected_body, response.headers.get("request-body")) self.assertEqual(response.body, b"Post method output")