Skip to content

Commit

Permalink
feat(transport): add head and post methods to requests transport
Browse files Browse the repository at this point in the history
Also renamed tranpost implementation tests files.

Refs: #6 #9
  • Loading branch information
piraz committed Apr 8, 2024
1 parent a6f5896 commit 010c83d
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 28 deletions.
4 changes: 2 additions & 2 deletions peasant/client/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
logger = logging.getLogger(__name__)


def concat_url(url: str, **kwargs: dict) -> str:
def concat_url(url: str, path: str = None, **kwargs: dict) -> str:
""" Concatenate a given url to a path, and query string if informed.
:param str url: Base url
:param str path: Path to be added to the returned url
:param dict kwargs:
:key path: Path to be added to the returned url
:key query_string: Query string to be added to the returned url
"""
path = kwargs.get("path", None)
query_string = kwargs.get("query_string", None)
if query_string:
if isinstance(query_string, dict):
Expand Down
37 changes: 28 additions & 9 deletions peasant/client/transport_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,48 @@ def __init__(self, bastion_address):
raise NotImplementedError
self._bastion_address = bastion_address
self._directory = None
self.user_agent = (f"Peasant/{get_version()}"
self.user_agent = (f"Peasant/{get_version()} "
f"Requests/{requests.__version__}")
self.basic_headers = {
'User-Agent': self.user_agent
}

def _get_path(self, path, **kwargs):
query_string = kwargs.get('query_string')
if query_string:
path = concat_url(path, query_string=query_string)

def get_headers(self, **kwargs):
headers = copy.deepcopy(self.basic_headers)
_headers = kwargs.get('headers')
if _headers:
headers.update(_headers)
return headers

async def get(self, path, **kwargs):
url = concat_url(self._bastion_address, **kwargs)
def head(self, path, **kwargs):
url = concat_url(self._bastion_address, path, **kwargs)
headers = self.get_headers(**kwargs)
kwargs['headers'] = headers
try:
with requests.head(url, **kwargs) as result:
result.raise_for_status()
except requests.HTTPError as error:
raise error
return result

def get(self, path, **kwargs):
url = concat_url(self._bastion_address, path, **kwargs)
headers = self.get_headers(**kwargs)
kwargs['headers'] = headers
try:
result = requests.get(url, headers=headers)
result = requests.get(url, **kwargs)
result.raise_for_status()
except requests.HTTPError as error:
result = error.response
return result

def post(self, path, **kwargs):
url = concat_url(self._bastion_address, path, **kwargs)
headers = self.get_headers(**kwargs)
kwargs['headers'] = headers
try:
with requests.post(url, **kwargs) as result:
result.raise_for_status()
except requests.HTTPError as error:
raise error
return result
20 changes: 10 additions & 10 deletions peasant/client/transport_tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, bastion_address) -> None:
self._client = AsyncHTTPClient()
self._bastion_address = fix_address(bastion_address)
self._directory = None
self.user_agent = (f"Peasant/{get_version()}"
self.user_agent = (f"Peasant/{get_version()} "
f"Tornado/{tornado_version}")
self._basic_headers = {
'User-Agent': self.user_agent
Expand All @@ -83,37 +83,37 @@ def get_headers(self, **kwargs):
headers.update(_headers)
return headers

async def get(self, **kwargs):
url = concat_url(self._bastion_address, **kwargs)
async def get(self, path: str, **kwargs: dict):
url = concat_url(self._bastion_address, path, **kwargs)
request = get_tornado_request(url, **kwargs)
headers = self.get_headers(**kwargs)
request.headers.update(headers)
try:
result = await self._client.fetch(request)
except HTTPClientError as error:
result = error.response
raise error
return result

async def head(self, **kwargs):
url = concat_url(self._bastion_address, **kwargs)
async def head(self, path: str, **kwargs: dict):
url = concat_url(self._bastion_address, path, **kwargs)
kwargs["method"] = "HEAD"
request = get_tornado_request(url, **kwargs)
headers = self.get_headers(**kwargs)
request.headers.update(headers)
try:
result = await self._client.fetch(request)
except HTTPClientError as error:
result = error.response
raise error
return result

async def post(self, **kwargs):
url = concat_url(self._bastion_address, **kwargs)
async def post(self, path: str, **kwargs: dict):
url = concat_url(self._bastion_address, path, **kwargs)
kwargs["method"] = "POST"
request = get_tornado_request(url, **kwargs)
headers = self.get_headers(**kwargs)
request.headers.update(headers)
try:
result = await self._client.fetch(request)
except HTTPClientError as error:
result = error.response
raise error
return result
7 changes: 4 additions & 3 deletions tests/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
# limitations under the License.

import unittest
from tests import requests_test, tornado_test, transport_test
from tests import (transport_requests_test, transport_test,
transport_tornado_test)


def suite():
testLoader = unittest.TestLoader()
alltests = unittest.TestSuite()
alltests.addTests(testLoader.loadTestsFromModule(requests_test))
alltests.addTests(testLoader.loadTestsFromModule(tornado_test))
alltests.addTests(testLoader.loadTestsFromModule(transport_requests_test))
alltests.addTests(testLoader.loadTestsFromModule(transport_test))
alltests.addTests(testLoader.loadTestsFromModule(transport_tornado_test))
return alltests


Expand Down
32 changes: 31 additions & 1 deletion tests/requests_test.py → tests/transport_requests_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,43 @@ def get_launcher(self) -> ProcessLauncher:

def setUp(self) -> None:
super().setUp()
# setting tests simplefilter to ignore because requests uses a
# keep-alive model, not closing sockets explicitly in many cases.
# with that will cause the ResourceWarning warn be displayed in testing
# as unittests will set warnings.simplefilter to default.
# See:
# - https://github.com/psf/requests/issues/3912#issuecomment-284328247
# - https://python.readthedocs.io/en/stable/library/warnings.html#updating-code-for-new-versions-of-python
import warnings
warnings.simplefilter("ignore")
self.transport = RequestsTransport(
f"http://localhost:{self.http_port()}")

@gen_test
async def test_head(self):
try:
response = self.transport.head("/head")
except Exception as e:
raise e
self.assertEqual(response.headers.get("head-response"),
"Head method response")
self.assertEqual(response.headers.get("user-agent"),
self.transport.user_agent)

@gen_test
async def test_get(self):
try:
response = await self.transport.get("/")
response = self.transport.get("/")
except Exception as e:
raise e
self.assertEqual(response.content, b"Get method output")

@gen_test
async def test_post(self):
expected_body = "da body"
try:
response = self.transport.post("/post", data="da body")
except Exception as e:
raise e
self.assertEqual(expected_body, response.headers.get("request-body"))
self.assertEqual(response.content, b"Post method output")
8 changes: 5 additions & 3 deletions tests/tornado_test.py → tests/transport_tornado_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def setUp(self) -> None:
@gen_test
async def test_head(self):
try:
response = await self.transport.head(path="/head")
response = await self.transport.head("/head")
except Exception as e:
raise e
self.assertEqual(response.headers.get("head-response"),
Expand All @@ -45,15 +45,17 @@ async def test_head(self):
@gen_test
async def test_get(self):
try:
response = await self.transport.get(path="/")
response = await self.transport.get("/")
except Exception as e:
raise e
self.assertEqual(response.body, b"Get method output")

@gen_test
async def test_post(self):
expected_body = "da body"
try:
response = await self.transport.post(path="/post", body="empty")
response = await self.transport.post("/post", body="da body")
except Exception as e:
raise e
self.assertEqual(expected_body, response.headers.get("request-body"))
self.assertEqual(response.body, b"Post method output")

0 comments on commit 010c83d

Please sign in to comment.